From aae998a35c2cf695e47eb9922c6b82b7b2e9c4fb Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 23 Sep 2024 16:40:23 +0200 Subject: [PATCH] [ENHANCEMENT] `argilla-server`: List records endpoint using db (#5170) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR rewrites the current list dataset records endpoint to use the DB instead of the search engine, since this endpoint applies no filtering. ~~The PR introduces a new abstraction layer to manage DB internals: a repository. With this layer, we have all DB methods related to a resource in a single place, which helps with maintainability and reusability.~~ DB query details have been moved to the Record model class, simplifying the context flows. A condensed sketch of the new listing flow is included after the diff. **Type of change** - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: José Francisco Calvo Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Damián Pumar Co-authored-by: José Francisco Calvo --- .../api/handlers/v1/datasets/records.py | 75 +++++------------- .../src/argilla_server/contexts/records.py | 77 +++++++++++++++++-- .../src/argilla_server/models/database.py | 13 +++- .../handlers/v1/test_list_dataset_records.py | 65 +++++++++++++--- argilla-server/tests/unit/conftest.py | 3 +- 5 files changed, 157 insertions(+), 76 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index d415dbbcd8..5c394a77a8 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Query, Security, status @@ -43,9 +43,9 @@ SearchSuggestionsOptions, SuggestionFilterScope, ) -from argilla_server.contexts import datasets, search +from argilla_server.contexts import datasets, search, records from argilla_server.database import get_async_db -from argilla_server.enums import RecordSortField, ResponseStatusFilter +from argilla_server.enums import RecordSortField from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE from argilla_server.models import Dataset, Field, Record, User, VectorSettings @@ -53,7 +53,6 @@ AndFilter, SearchEngine, SearchResponses, - UserResponseStatusFilter, get_search_engine, ) from argilla_server.security import auth @@ -72,42 +71,13 @@ router = APIRouter() -async def _filter_records_using_search_engine( - db: "AsyncSession", - search_engine: "SearchEngine", - dataset: Dataset, - limit: int, - offset: int, - user: Optional[User] = None, - include: Optional[RecordIncludeParam] = None, -) -> Tuple[List[Record], int]: - search_responses = await _get_search_responses( - db=db, - search_engine=search_engine, - dataset=dataset, - limit=limit, - offset=offset, - user=user, - ) - - record_ids = [response.record_id for response in search_responses.items] - user_id = user.id if user else None - - return ( - await datasets.get_records_by_ids( - db=db, dataset_id=dataset.id, user_id=user_id, records_ids=record_ids, include=include - ), - search_responses.total, - ) - - def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope: if isinstance(scope, RecordFilterScope): return search_engine.RecordFilterScope(property=scope.property) elif isinstance(scope, MetadataFilterScope): return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property) elif isinstance(scope, SuggestionFilterScope): - return search_engine.SuggestionFilterScope(question=scope.question, property=scope.property) + return search_engine.SuggestionFilterScope(question=scope.question, property=str(scope.property)) elif isinstance(scope, ResponseFilterScope): return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user) else: @@ -223,18 +193,6 @@ async def _get_search_responses( return await search_engine.search(**search_params) -async def _build_response_status_filter_for_search( - response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None -) -> Optional[UserResponseStatusFilter]: - user_response_status_filter = None - - if response_statuses: - # TODO(@frascuchon): user response and status responses should be split into different filter types - user_response_status_filter = UserResponseStatusFilter(user=user, statuses=response_statuses) - - return user_response_status_filter - - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset: Dataset): try: await search.validate_search_records_query(db, query, dataset) @@ -246,7 +204,6 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord async def list_dataset_records( *, db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, include: Optional[RecordIncludeParam] = 
Depends(parse_record_include_param), offset: int = 0, @@ -254,19 +211,27 @@ async def list_dataset_records( current_user: User = Security(auth.get_current_user), ): dataset = await Dataset.get_or_raise(db, dataset_id) - await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset)) - records, total = await _filter_records_using_search_engine( - db, - search_engine, - dataset=dataset, - limit=limit, + include_args = ( + dict( + with_responses=include.with_responses, + with_suggestions=include.with_suggestions, + with_vectors=include.with_all_vectors or include.vectors, + ) + if include + else {} + ) + + dataset_records, total = await records.list_dataset_records( + db=db, + dataset_id=dataset.id, offset=offset, - include=include, + limit=limit, + **include_args, ) - return Records(items=records, total=total) + return Records(items=dataset_records, total=total) @router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT) diff --git a/argilla-server/src/argilla_server/contexts/records.py b/argilla-server/src/argilla_server/contexts/records.py index c2b0f20bb9..0764b3152f 100644 --- a/argilla-server/src/argilla_server/contexts/records.py +++ b/argilla-server/src/argilla_server/contexts/records.py @@ -12,21 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Sequence +from typing import Dict, Sequence, Union, List, Tuple, Optional from uuid import UUID -from sqlalchemy import select +from sqlalchemy import select, and_, func, Select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import selectinload, contains_eager -from argilla_server.models import Dataset, Record +from argilla_server.database import get_async_db +from argilla_server.models import Dataset, Record, VectorSettings, Vector + + +async def list_dataset_records( + db: AsyncSession, + dataset_id: UUID, + offset: int, + limit: int, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, +) -> Tuple[Sequence[Record], int]: + query = _record_by_dataset_id_query( + dataset_id=dataset_id, + offset=offset, + limit=limit, + with_responses=with_responses, + with_suggestions=with_suggestions, + with_vectors=with_vectors, + ) + + records = (await db.scalars(query)).unique().all() + total = await db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id)) + + return records, total async def list_dataset_records_by_ids( db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID] ) -> Sequence[Record]: - query = select(Record).filter(Record.id.in_(record_ids), Record.dataset_id == dataset_id) - return (await db.execute(query)).unique().scalars().all() + query = select(Record).where(and_(Record.id.in_(record_ids), Record.dataset_id == dataset_id)) + return (await db.scalars(query)).unique().all() async def list_dataset_records_by_external_ids( @@ -34,10 +59,11 @@ async def list_dataset_records_by_external_ids( ) -> Sequence[Record]: query = ( select(Record) - .filter(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id) + .where(and_(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id)) .options(selectinload(Record.dataset)) ) - return (await db.execute(query)).unique().scalars().all() + + return (await db.scalars(query)).unique().all() async def fetch_records_by_ids_as_dict( @@ -52,3 +78,38 @@ async def fetch_records_by_external_ids_as_dict( 
) -> Dict[str, Record]: records_by_external_ids = await list_dataset_records_by_external_ids(db, dataset.id, external_ids) return {record.external_id: record for record in records_by_external_ids} + + +def _record_by_dataset_id_query( + dataset_id, + offset: Optional[int] = None, + limit: Optional[int] = None, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, +) -> Select: + query = select(Record).filter_by(dataset_id=dataset_id) + + if with_responses: + query = query.options(selectinload(Record.responses)) + + if with_suggestions: + query = query.options(selectinload(Record.suggestions)) + + if with_vectors is True: + query = query.options(selectinload(Record.vectors)) + elif isinstance(with_vectors, list): + subquery = select(VectorSettings.id).filter( + and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) + ) + query = query.outerjoin( + Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) + ).options(contains_eager(Record.vectors)) + + if offset is not None: + query = query.offset(offset) + + if limit is not None: + query = query.limit(limit) + + return query.order_by(Record.inserted_at) diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index f86d90f3f7..96f3b70897 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -17,12 +17,20 @@ from typing import Any, List, Optional, Union from uuid import UUID -from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text from sqlalchemy import Enum as SAEnum +from sqlalchemy import ( + JSON, + ForeignKey, + String, + Text, + UniqueConstraint, + and_, + sql, +) from sqlalchemy.engine.default import DefaultExecutionContext from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.mutable import MutableDict, MutableList -from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property +from sqlalchemy.orm import Mapped, mapped_column, relationship from argilla_server.api.schemas.v1.questions import QuestionSettings from argilla_server.enums import ( @@ -30,7 +38,6 @@ FieldType, MetadataPropertyType, QuestionType, - RecordStatus, ResponseStatus, SuggestionType, UserRole, diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 4f989e5399..3dc3546d29 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -38,7 +38,6 @@ @pytest.mark.asyncio class TestSuiteListDatasetRecords: - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -58,25 +57,31 @@ async def test_list_dataset_records(self, async_client: "AsyncClient", owner_aut "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"record_a": "value_a"}, "metadata": None, "external_id": record_a.external_id, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), 
"fields": {"record_b": "value_b"}, "metadata": {"unit": "test"}, "external_id": record_b.external_id, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"record_c": "value_c"}, "metadata": None, "external_id": record_c.external_id, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -188,7 +193,6 @@ async def test_list_dataset_records_with_include( assert response.status_code == 200 - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_include_vectors( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -214,6 +218,7 @@ async def test_list_dataset_records_with_include_vectors( "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -221,26 +226,31 @@ async def test_list_dataset_records_with_include_vectors( "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, "vectors": { "vector-b": [1.0, 2.0], }, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -248,7 +258,6 @@ async def test_list_dataset_records_with_include_vectors( "total": 3, } - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_include_specific_vectors( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -278,6 +287,7 @@ async def test_list_dataset_records_with_include_specific_vectors( "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -285,26 +295,31 @@ async def test_list_dataset_records_with_include_specific_vectors( "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, "vectors": { "vector-b": [1.0, 2.0], }, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -312,7 +327,6 @@ async def test_list_dataset_records_with_include_specific_vectors( "total": 3, } - 
@pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_offset(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -331,7 +345,6 @@ async def test_list_dataset_records_with_offset(self, async_client: "AsyncClient response_body = response.json() assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_limit(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -350,7 +363,6 @@ async def test_list_dataset_records_with_limit(self, async_client: "AsyncClient" response_body = response.json() assert [item["id"] for item in response_body["items"]] == [str(record_a.id)] - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_offset_and_limit( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -457,9 +469,9 @@ async def test_list_dataset_records_as_admin(self, async_client: "AsyncClient"): admin = await AdminFactory.create(workspaces=[workspace]) dataset = await DatasetFactory.create(workspace=workspace) - await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) + record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) + record_b = await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) + record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) other_dataset = await DatasetFactory.create() await RecordFactory.create_batch(size=2, dataset=other_dataset) @@ -468,6 +480,41 @@ async def test_list_dataset_records_as_admin(self, async_client: "AsyncClient"): f"/api/v1/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: admin.api_key} ) assert response.status_code == 200 + assert response.json() == { + "total": 3, + "items": [ + { + "id": str(record_a.id), + "dataset_id": str(dataset.id), + "fields": {"record_a": "value_a"}, + "metadata": None, + "external_id": record_a.external_id, + "status": "pending", + "inserted_at": record_a.inserted_at.isoformat(), + "updated_at": record_a.updated_at.isoformat(), + }, + { + "id": str(record_b.id), + "dataset_id": str(dataset.id), + "fields": {"record_b": "value_b"}, + "metadata": None, + "external_id": record_b.external_id, + "status": "pending", + "inserted_at": record_b.inserted_at.isoformat(), + "updated_at": record_b.updated_at.isoformat(), + }, + { + "id": str(record_c.id), + "dataset_id": str(dataset.id), + "fields": {"record_c": "value_c"}, + "metadata": None, + "external_id": record_c.external_id, + "status": "pending", + "inserted_at": record_c.inserted_at.isoformat(), + "updated_at": record_c.updated_at.isoformat(), + }, + ], + } async def test_list_dataset_records_as_annotator(self, async_client: "AsyncClient"): workspace = await WorkspaceFactory.create() diff --git a/argilla-server/tests/unit/conftest.py b/argilla-server/tests/unit/conftest.py index a702be6c36..4af66e9fb2 100644 --- a/argilla-server/tests/unit/conftest.py +++ b/argilla-server/tests/unit/conftest.py @@ -22,7 +22,7 @@ from 
opensearchpy import OpenSearch from argilla_server import telemetry -from argilla_server.contexts import distribution, datasets +from argilla_server.contexts import distribution, datasets, records from argilla_server.api.routes import api_v1 from argilla_server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY from argilla_server.database import get_async_db @@ -91,6 +91,7 @@ async def override_get_search_engine(): mocker.patch.object(distribution, "_get_async_db", override_get_async_db) mocker.patch.object(datasets, "get_async_db", override_get_async_db) + mocker.patch.object(records, "get_async_db", override_get_async_db) api_v1.dependency_overrides.update( {
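
For quick reference, here is a condensed sketch of the new DB-backed listing flow introduced in `argilla_server/contexts/records.py` above. It is a simplification, not the full implementation: the eager-loading options for responses, suggestions, and vectors (`selectinload`/`contains_eager`) are dropped, and the standalone imports assume the module layout shown in the diff.

```python
# Condensed sketch of the DB-backed record listing (simplified from the diff;
# eager loading of responses/suggestions/vectors is omitted here).
from typing import Sequence, Tuple
from uuid import UUID

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.models import Record


async def list_dataset_records(
    db: AsyncSession, dataset_id: UUID, offset: int, limit: int
) -> Tuple[Sequence[Record], int]:
    # One SELECT against the records table -- no search engine round trip.
    query = (
        select(Record)
        .filter_by(dataset_id=dataset_id)
        .order_by(Record.inserted_at)
        .offset(offset)
        .limit(limit)
    )
    records = (await db.scalars(query)).unique().all()

    # A separate COUNT so `total` reflects the whole dataset, not just the page.
    total = await db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id))

    return records, total
```

The handler then returns `Records(items=dataset_records, total=total)` directly, so the pagination total no longer depends on `SearchResponses.total` coming back from the engine.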