Skip to content

Commit

Permalink
perf: Using search engine to compute the total number of records for user metrics (#5641)
Browse files Browse the repository at this point in the history

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

Use the search engine's total instead of a DB count to get the total number
of records. This improves performance when running the Argilla server
on an HF Space with persistent storage enabled.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Improvement (a change that improves existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Oct 28, 2024
1 parent 432f557 commit 9bc6ff3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async def get_current_user_dataset_metrics(

await authorize(current_user, DatasetPolicy.get(dataset))

result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset)
result = await datasets.get_user_dataset_metrics(search_engine, current_user, dataset)

return DatasetMetrics(responses=result)

Expand Down
3 changes: 1 addition & 2 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,11 @@ async def _configure_query_relationships(


async def get_user_dataset_metrics(
db: AsyncSession,
search_engine: SearchEngine,
user: User,
dataset: Dataset,
) -> dict:
total_records = await Record.count_by(db, dataset_id=dataset.id)
total_records = (await get_dataset_progress(search_engine, dataset))["total"]
result = await search_engine.get_dataset_user_progress(dataset, user)

submitted_responses = result.get("submitted", 0)
Expand Down
12 changes: 9 additions & 3 deletions argilla-server/tests/unit/api/handlers/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,8 @@ async def test_get_current_user_dataset_metrics(
dataset = await DatasetFactory.create()
records = await RecordFactory.create_batch(size=8, dataset=dataset)

mock_search_engine.get_dataset_progress.return_value = {"total": len(records)}

mock_search_engine.get_dataset_user_progress.return_value = {
"total": 6,
"submitted": 3,
Expand Down Expand Up @@ -772,6 +774,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset(
):
dataset = await DatasetFactory.create()

mock_search_engine.get_dataset_progress.return_value = {}
mock_search_engine.get_dataset_user_progress.return_value = {}

response = await async_client.get(
Expand All @@ -791,7 +794,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset(
}

@pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin])
async def test_get_current_user_dataset_metrics_as_annotator(
async def test_get_current_user_dataset_metrics_as_different_role(
self,
async_client: "AsyncClient",
mock_search_engine: SearchEngine,
Expand All @@ -800,10 +803,13 @@ async def test_get_current_user_dataset_metrics_as_annotator(
dataset = await DatasetFactory.create()
records = await RecordFactory.create_batch(size=6, dataset=dataset)

user = await AnnotatorFactory.create(workspaces=[dataset.workspace], role=role)
user = await UserFactory.create(workspaces=[dataset.workspace], role=role)

mock_search_engine.get_dataset_user_progress.return_value = {
mock_search_engine.get_dataset_progress.return_value = {
"total": len(records),
}

mock_search_engine.get_dataset_user_progress.return_value = {
"submitted": 2,
"discarded": 1,
"draft": 1,
Expand Down

0 comments on commit 9bc6ff3

Please sign in to comment.