From 432f5578a8376da35849fa21ae4ff803b8bc1a29 Mon Sep 17 00:00:00 2001 From: bharath97-git <70206310+bharath97-git@users.noreply.github.com> Date: Mon, 28 Oct 2024 19:47:43 +0530 Subject: [PATCH 1/2] feat: added support for `with_vectors` with query filter in sdk (#5638) # Description Closes #5636 **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** Passed `with_vectors` parameter to the search method while retrieving records and I could see the vectors in the response now. - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: bharath Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- argilla/CHANGELOG.md | 1 + argilla/src/argilla/_api/_records.py | 4 +++- argilla/src/argilla/records/_dataset_records.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/argilla/CHANGELOG.md b/argilla/CHANGELOG.md index e7b8bf6b7b..2136f769ed 100644 --- a/argilla/CHANGELOG.md +++ b/argilla/CHANGELOG.md @@ -19,6 +19,7 @@ These are the section headers that we use: ### Changed - Terms metadata properties accept other values than `str`. ([#5594](https://github.com/argilla-io/argilla/pull/5594)) +- Added support for `with_vectors` while fetching records along with a search query. ([#5638](https://github.com/argilla-io/argilla/pull/5638)) ## [2.3.0](https://github.com/argilla-io/argilla/compare/v2.2.2...v2.3.0) diff --git a/argilla/src/argilla/_api/_records.py b/argilla/src/argilla/_api/_records.py index 3c5f13270b..6581655839 100644 --- a/argilla/src/argilla/_api/_records.py +++ b/argilla/src/argilla/_api/_records.py @@ -111,13 +111,15 @@ def search( limit: int = 100, with_suggestions: bool = True, with_responses: bool = True, - # TODO: Add support for `with_vectors` + with_vectors: Optional[Union[List, bool]] = None, ) -> Tuple[List[Tuple[RecordModel, float]], int]: include = [] if with_suggestions: include.append("suggestions") if with_responses: include.append("responses") + if with_vectors: + include.append(self._represent_vectors_to_include(with_vectors)) params = { "offset": offset, "limit": limit, diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index 2887fd69d2..eb9b21afa3 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -132,6 +132,7 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]: offset=self.__offset, with_responses=self.__with_responses, with_suggestions=self.__with_suggestions, + with_vectors=self.__with_vectors, ) return [record_model for record_model, _ in search_items] From 9bc6ff3a5869797e7218fa0b2e2b0519613dcd7a Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 28 Oct 2024 16:39:47 +0100 Subject: [PATCH 2/2] perf: Using search engine to compute the total number of records for user metrics (#5641) # Description Using search engine total instead of DB count to get the total number of records. This improves the performance when running the Argilla server on an HF space with persistent storage enabled. **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../api/handlers/v1/datasets/datasets.py | 2 +- .../src/argilla_server/contexts/datasets.py | 3 +-- .../tests/unit/api/handlers/v1/test_datasets.py | 12 +++++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py index 13c5e45e97..5d72063ed5 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py @@ -152,7 +152,7 @@ async def get_current_user_dataset_metrics( await authorize(current_user, DatasetPolicy.get(dataset)) - result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset) + result = await datasets.get_user_dataset_metrics(search_engine, current_user, dataset) return DatasetMetrics(responses=result) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 5de43865d1..950cd9fb43 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -363,12 +363,11 @@ async def _configure_query_relationships( async def get_user_dataset_metrics( - db: AsyncSession, search_engine: SearchEngine, user: User, dataset: Dataset, ) -> dict: - total_records = await Record.count_by(db, dataset_id=dataset.id) + total_records = (await get_dataset_progress(search_engine, dataset))["total"] result = await search_engine.get_dataset_user_progress(dataset, user) submitted_responses = result.get("submitted", 0) diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 99020470b7..2919227f96 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -742,6 +742,8 @@ async def test_get_current_user_dataset_metrics( dataset = await DatasetFactory.create() records = await RecordFactory.create_batch(size=8, dataset=dataset) + mock_search_engine.get_dataset_progress.return_value = {"total": len(records)} + mock_search_engine.get_dataset_user_progress.return_value = { "total": 6, "submitted": 3, @@ -772,6 +774,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset( ): dataset = await DatasetFactory.create() + mock_search_engine.get_dataset_progress.return_value = {} mock_search_engine.get_dataset_user_progress.return_value = {} response = await async_client.get( @@ -791,7 +794,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset( } @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) - async def test_get_current_user_dataset_metrics_as_annotator( + async def test_get_current_user_dataset_metrics_as_different_role( self, async_client: "AsyncClient", mock_search_engine: SearchEngine, @@ -800,10 +803,13 @@ async def test_get_current_user_dataset_metrics_as_annotator( dataset = await DatasetFactory.create() records = await RecordFactory.create_batch(size=6, dataset=dataset) - user = await AnnotatorFactory.create(workspaces=[dataset.workspace], role=role) + user = await UserFactory.create(workspaces=[dataset.workspace], role=role) - mock_search_engine.get_dataset_user_progress.return_value = { + mock_search_engine.get_dataset_progress.return_value = { "total": len(records), + } + + mock_search_engine.get_dataset_user_progress.return_value = { "submitted": 2, "discarded": 1, "draft": 1,