Skip to content

Commit

Permalink
Merge branch 'develop' into feat/argilla-direct-feature-branch
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon authored Oct 28, 2024
2 parents 06f4aaa + 9bc6ff3 commit 1bbb9d1
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def get_current_user_dataset_metrics(

await authorize(current_user, DatasetPolicy.get(dataset))

result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset)
result = await datasets.get_user_dataset_metrics(search_engine, current_user, dataset)

return DatasetMetrics(responses=result)

Expand Down
3 changes: 1 addition & 2 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,11 @@ async def _configure_query_relationships(


async def get_user_dataset_metrics(
db: AsyncSession,
search_engine: SearchEngine,
user: User,
dataset: Dataset,
) -> dict:
total_records = await Record.count_by(db, dataset_id=dataset.id)
total_records = (await get_dataset_progress(search_engine, dataset))["total"]
result = await search_engine.get_dataset_user_progress(dataset, user)

submitted_responses = result.get("submitted", 0)
Expand Down
12 changes: 9 additions & 3 deletions argilla-server/tests/unit/api/handlers/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ async def test_get_current_user_dataset_metrics(
dataset = await DatasetFactory.create()
records = await RecordFactory.create_batch(size=8, dataset=dataset)

mock_search_engine.get_dataset_progress.return_value = {"total": len(records)}

mock_search_engine.get_dataset_user_progress.return_value = {
"total": 6,
"submitted": 3,
Expand Down Expand Up @@ -776,6 +778,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset(
):
dataset = await DatasetFactory.create()

mock_search_engine.get_dataset_progress.return_value = {}
mock_search_engine.get_dataset_user_progress.return_value = {}

response = await async_client.get(
Expand All @@ -795,7 +798,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset(
}

@pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin])
async def test_get_current_user_dataset_metrics_as_annotator(
async def test_get_current_user_dataset_metrics_as_different_role(
self,
async_client: "AsyncClient",
mock_search_engine: SearchEngine,
Expand All @@ -804,10 +807,13 @@ async def test_get_current_user_dataset_metrics_as_annotator(
dataset = await DatasetFactory.create()
records = await RecordFactory.create_batch(size=6, dataset=dataset)

user = await AnnotatorFactory.create(workspaces=[dataset.workspace], role=role)
user = await UserFactory.create(workspaces=[dataset.workspace], role=role)

mock_search_engine.get_dataset_user_progress.return_value = {
mock_search_engine.get_dataset_progress.return_value = {
"total": len(records),
}

mock_search_engine.get_dataset_user_progress.return_value = {
"submitted": 2,
"discarded": 1,
"draft": 1,
Expand Down
1 change: 1 addition & 0 deletions argilla/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ These are the section headers that we use:
### Changed

- Terms metadata properties accept other values than `str`. ([#5594](https://github.com/argilla-io/argilla/pull/5594))
- Added support for `with_vectors` while fetching records along with a search query. ([#5638](https://github.com/argilla-io/argilla/pull/5638))

### Removed

Expand Down
4 changes: 3 additions & 1 deletion argilla/src/argilla/_api/_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def search(
limit: int = 100,
with_suggestions: bool = True,
with_responses: bool = True,
# TODO: Add support for `with_vectors`
with_vectors: Optional[Union[List, bool]] = None,
) -> Tuple[List[Tuple[RecordModel, float]], int]:
include = []
if with_suggestions:
include.append("suggestions")
if with_responses:
include.append("responses")
if with_vectors:
include.append(self._represent_vectors_to_include(with_vectors))
params = {
"offset": offset,
"limit": limit,
Expand Down
1 change: 1 addition & 0 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]:
offset=self.__offset,
with_responses=self.__with_responses,
with_suggestions=self.__with_suggestions,
with_vectors=self.__with_vectors,
)
return [record_model for record_model, _ in search_items]

Expand Down

0 comments on commit 1bbb9d1

Please sign in to comment.