From f62d58a2f91e16eb4d02e56c4039a432070349c0 Mon Sep 17 00:00:00 2001
From: José Francisco Calvo
Date: Mon, 1 Jul 2024 12:31:02 +0200
Subject: [PATCH 01/34] feat: add dataset support to be created using distribution settings (#5013)

# Description

This PR is the first one related to the distribution task feature, and it adds the following changes:

* Added a `distribution` JSON column to the `datasets` table:
  * This column is non-nullable, so a value is always required when a dataset is created.
  * By default, old datasets will have the value `{"strategy": "overlap", "min_submitted": 1}`.
* Added a `distribution` attribute to the `DatasetCreate` schema:
  * `None` is not a valid value.
  * If no value is specified for this attribute, a `DatasetOverlapDistributionCreate` with `min_submitted` set to `1` is used.
  * `DatasetOverlapDistributionCreate` only allows values greater than or equal to `1` for the `min_submitted` attribute.
* The context `create_dataset` function now receives a dictionary instead of a `DatasetCreate` schema.
* Moved dataset creation validations to a new `DatasetCreateValidator` class.

An example request using the new `distribution` attribute is sketched below. Updating the `distribution` attribute of existing datasets will be done in a different issue.

Closes #5005

**Type of change**

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing functionality)
- [ ] Improvement (change adding some improvement to an existing functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And ideally, reference `tests`)

- [x] Adding new tests and passing old ones.
- [x] Checked that the migration works as expected with old datasets on SQLite.
- [x] Checked that the migration works as expected with old datasets on PostgreSQL.
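For illustration only (this snippet is not part of the PR's code), this is how a client could create a dataset with an explicit `distribution` value once these server changes are deployed. The base URL, API key header, and workspace id are placeholders assumed for this sketch:

```python
# Minimal sketch of a request against the dataset creation endpoint, assuming an
# Argilla server on localhost:6900, a valid owner API key, and an existing workspace.
import httpx

client = httpx.Client(
    base_url="http://localhost:6900",
    headers={"X-Argilla-Api-Key": "owner.apikey"},  # header name and value are placeholders
)

response = client.post(
    "/api/v1/datasets",
    json={
        "name": "my-dataset",
        "workspace_id": "00000000-0000-0000-0000-000000000000",  # placeholder UUID
        # Optional: when omitted, the default {"strategy": "overlap", "min_submitted": 1} is stored.
        "distribution": {"strategy": "overlap", "min_submitted": 2},
    },
)
response.raise_for_status()
print(response.json()["distribution"])  # {"strategy": "overlap", "min_submitted": 2}
```

A `min_submitted` value below `1` or an unknown `strategy` is rejected with a `422` response, matching the new validations covered by the tests in this PR.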
**Checklist**

- [ ] I added relevant documentation
- [ ] My code follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Paco Aranda
---
 .../repositories/RecordRepository.ts          |  30 ++-
 argilla-server/CHANGELOG.md                   |   7 +-
 ...4d74_add_status_column_to_records_table.py |  60 ++++++
 ...7_add_metadata_column_to_records_table.py} |   6 +-
 ...d_distribution_column_to_datasets_table.py |  45 +++++
 ...xtra_metadata_column_to_datasets_table.py} |   6 +-
 .../api/handlers/v1/datasets/datasets.py      |   4 +-
 .../api/handlers/v1/responses.py              |   8 +-
 .../argilla_server/api/schemas/v1/datasets.py |  38 +++-
 .../argilla_server/api/schemas/v1/records.py  |   5 +-
 .../src/argilla_server/bulk/records_bulk.py   |   4 +
 .../src/argilla_server/contexts/datasets.py   |  56 +++---
 .../argilla_server/contexts/distribution.py   |  42 +++++
 argilla-server/src/argilla_server/enums.py    |   9 +
 .../src/argilla_server/models/database.py     |  31 ++-
 .../src/argilla_server/search_engine/base.py  |   4 +
 .../argilla_server/search_engine/commons.py   |   6 +
 .../src/argilla_server/validators/datasets.py |  48 +++++
 argilla-server/tests/factories.py             |   3 +-
 .../records_bulk/test_dataset_records_bulk.py |   3 +-
 .../v1/datasets/test_create_dataset.py        | 139 ++++++++++++++
 ...est_search_current_user_dataset_records.py |   5 +-
 .../datasets/test_search_dataset_records.py   |   4 +-
 .../v1/datasets/test_update_dataset.py        | 178 ++++++++++++++++++
 .../test_create_dataset_records_bulk.py       | 145 ++++++++++++++
 .../v1/records/test_create_record_response.py | 100 ++++++++--
 .../test_upsert_dataset_records_bulk.py       | 153 +++++++++++++++
 ...test_create_current_user_responses_bulk.py |  10 +-
 .../v1/responses/test_delete_response.py      |  66 +++++++
 .../v1/responses/test_update_response.py      |  71 ++++++-
 .../unit/api/handlers/v1/test_datasets.py     |  39 +++-
 .../handlers/v1/test_list_dataset_records.py  |   6 +-
 .../unit/api/handlers/v1/test_records.py      |  10 +-
 .../tests/unit/search_engine/test_commons.py  |  12 +-
 argilla/src/argilla/_models/_search.py        |   6 +
 argilla/src/argilla/records/_search.py        |   4 +-
 36 files changed, 1274 insertions(+), 89 deletions(-)
 create mode 100644 argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py
 rename argilla-server/src/argilla_server/alembic/versions/{3ff6484f8b37_add_record_metadata_column.py => 3ff6484f8b37_add_metadata_column_to_records_table.py} (82%)
 create mode 100644 argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py
 rename argilla-server/src/argilla_server/alembic/versions/{b8458008b60e_add_allow_extra_metadata_column_to_.py => b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py} (81%)
 create mode 100644 argilla-server/src/argilla_server/contexts/distribution.py
 create mode 100644 argilla-server/src/argilla_server/validators/datasets.py
 create mode 100644 argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py
 create mode 100644 argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py
 create mode 100644
argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py create mode 100644 argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py create mode 100644 argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py diff --git a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts index e0e30adfd3..40ce2645eb 100644 --- a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts @@ -42,10 +42,8 @@ export class RecordRepository { constructor(private readonly axios: NuxtAxiosInstance) {} getRecords(criteria: RecordCriteria): Promise { - if (criteria.isFilteringByAdvanceSearch) - return this.getRecordsByAdvanceSearch(criteria); - - return this.getRecordsByDatasetId(criteria); + return this.getRecordsByAdvanceSearch(criteria); + // return this.getRecordsByDatasetId(criteria); } async getRecord(recordId: string): Promise { @@ -264,6 +262,30 @@ export class RecordRepository { }; } + body.filters = { + and: [ + { + type: "terms", + scope: { + entity: "response", + property: "status", + }, + values: [status], + }, + ], + }; + + if (status === "pending") { + body.filters.and.push({ + type: "terms", + scope: { + entity: "record", + property: "status", + }, + values: ["pending"], + }); + } + if ( isFilteringByMetadata || isFilteringByResponse || diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index de84587e41..827037a2c3 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -16,12 +16,17 @@ These are the section headers that we use: ## [Unreleased]() -## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) +### Added + +- Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013)) +- Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028)) ### Changed - Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) +## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) + ### Removed - Removed all API v0 endpoints. ([#4852](https://github.com/argilla-io/argilla/pull/4852)) diff --git a/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py b/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py new file mode 100644 index 0000000000..767b277573 --- /dev/null +++ b/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py @@ -0,0 +1,60 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""add status column to records table + +Revision ID: 237f7c674d74 +Revises: 45a12f74448b +Create Date: 2024-06-18 17:59:36.992165 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "237f7c674d74" +down_revision = "45a12f74448b" +branch_labels = None +depends_on = None + + +record_status_enum = sa.Enum("pending", "completed", name="record_status_enum") + + +def upgrade() -> None: + record_status_enum.create(op.get_bind()) + + op.add_column("records", sa.Column("status", record_status_enum, server_default="pending", nullable=False)) + op.create_index(op.f("ix_records_status"), "records", ["status"], unique=False) + + # NOTE: Updating existent records to have "completed" status when they have + # at least one response with "submitted" status. + op.execute(""" + UPDATE records + SET status = 'completed' + WHERE id IN ( + SELECT DISTINCT record_id + FROM responses + WHERE status = 'submitted' + ); + """) + + +def downgrade() -> None: + op.drop_index(op.f("ix_records_status"), table_name="records") + op.drop_column("records", "status") + + record_status_enum.drop(op.get_bind()) diff --git a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py similarity index 82% rename from argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py rename to argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py index 7ac80ad895..b5949f5364 100644 --- a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py +++ b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add record metadata column +"""add metadata column to records table Revision ID: 3ff6484f8b37 Revises: ae5522b4c674 @@ -31,12 +31,8 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True)) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.drop_column("records", "metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py new file mode 100644 index 0000000000..791da07439 --- /dev/null +++ b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py @@ -0,0 +1,45 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""add distribution column to datasets table + +Revision ID: 45a12f74448b +Revises: d00f819ccc67 +Create Date: 2024-06-13 11:23:43.395093 + +""" + +import json + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "45a12f74448b" +down_revision = "d00f819ccc67" +branch_labels = None +depends_on = None + +DISTRIBUTION_VALUE = json.dumps({"strategy": "overlap", "min_submitted": 1}) + + +def upgrade() -> None: + op.add_column("datasets", sa.Column("distribution", sa.JSON(), nullable=True)) + op.execute(f"UPDATE datasets SET distribution = '{DISTRIBUTION_VALUE}'") + with op.batch_alter_table("datasets") as batch_op: + batch_op.alter_column("distribution", nullable=False) + + +def downgrade() -> None: + op.drop_column("datasets", "distribution") diff --git a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py similarity index 81% rename from argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py rename to argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py index 8b23340448..f8fa87536e 100644 --- a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py +++ b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add allow_extra_metadata column to dataset table +"""add allow_extra_metadata column to datasets table Revision ID: b8458008b60e Revises: 7cbcccf8b57a @@ -31,14 +31,10 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column( "datasets", sa.Column("allow_extra_metadata", sa.Boolean(), server_default=sa.text("true"), nullable=False) ) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### op.drop_column("datasets", "allow_extra_metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py index 63f95391e1..0590b41bb4 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py @@ -189,7 +189,7 @@ async def create_dataset( ): await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id)) - return await datasets.create_dataset(db, dataset_create) + return await datasets.create_dataset(db, dataset_create.dict()) @router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field) @@ -302,4 +302,4 @@ async def update_dataset( await authorize(current_user, DatasetPolicy.update(dataset)) - return await datasets.update_dataset(db, dataset, dataset_update) + return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True)) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/responses.py b/argilla-server/src/argilla_server/api/handlers/v1/responses.py index 56cb695c95..ddc389563a 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/responses.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/responses.py @@ -64,7 +64,9 @@ async def update_response( response = await Response.get_or_raise( db, response_id, - options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)], + options=[ + selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions), + ], ) await authorize(current_user, ResponsePolicy.update(response)) @@ -83,7 +85,9 @@ async def delete_response( response = await Response.get_or_raise( db, response_id, - options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)], + options=[ + selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions), + ], ) await authorize(current_user, ResponsePolicy.delete(response)) diff --git a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py index 5cac33bdb7..1e1b69d836 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py @@ -13,11 +13,11 @@ # limitations under the License. 
from datetime import datetime -from typing import List, Optional +from typing import List, Literal, Optional, Union from uuid import UUID from argilla_server.api.schemas.v1.commons import UpdateSchema -from argilla_server.enums import DatasetStatus +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus from argilla_server.pydantic_v1 import BaseModel, Field, constr try: @@ -44,6 +44,32 @@ ] +class DatasetOverlapDistribution(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int + + +DatasetDistribution = DatasetOverlapDistribution + + +class DatasetOverlapDistributionCreate(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int = Field( + ge=1, + description="Minimum number of submitted responses to consider a record as completed", + ) + + +DatasetDistributionCreate = DatasetOverlapDistributionCreate + + +class DatasetOverlapDistributionUpdate(DatasetDistributionCreate): + pass + + +DatasetDistributionUpdate = DatasetOverlapDistributionUpdate + + class RecordMetrics(BaseModel): count: int @@ -74,6 +100,7 @@ class Dataset(BaseModel): guidelines: Optional[str] allow_extra_metadata: bool status: DatasetStatus + distribution: DatasetDistribution workspace_id: UUID last_activity_at: datetime inserted_at: datetime @@ -91,6 +118,10 @@ class DatasetCreate(BaseModel): name: DatasetName guidelines: Optional[DatasetGuidelines] allow_extra_metadata: bool = True + distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate( + strategy=DatasetDistributionStrategy.overlap, + min_submitted=1, + ) workspace_id: UUID @@ -98,5 +129,6 @@ class DatasetUpdate(UpdateSchema): name: Optional[DatasetName] guidelines: Optional[DatasetGuidelines] allow_extra_metadata: Optional[bool] + distribution: Optional[DatasetDistributionUpdate] - __non_nullable_fields__ = {"name", "allow_extra_metadata"} + __non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"} diff --git a/argilla-server/src/argilla_server/api/schemas/v1/records.py b/argilla-server/src/argilla_server/api/schemas/v1/records.py index 13f37c3ae0..0cf215954a 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/records.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/records.py @@ -23,7 +23,7 @@ from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate from argilla_server.api.schemas.v1.suggestions import Suggestion, SuggestionCreate, SuggestionFilterScope -from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder +from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder, RecordStatus from argilla_server.pydantic_v1 import BaseModel, Field, StrictStr, root_validator, validator from argilla_server.pydantic_v1.utils import GetterDict from argilla_server.search_engine import TextQuery @@ -66,6 +66,7 @@ def get(self, key: str, default: Any) -> Any: class Record(BaseModel): id: UUID + status: RecordStatus fields: Dict[str, Any] metadata: Optional[Dict[str, Any]] external_id: Optional[str] @@ -196,7 +197,7 @@ def _has_relationships(self): class RecordFilterScope(BaseModel): entity: Literal["record"] - property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at]] + property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal["status"]] class Records(BaseModel): diff --git 
a/argilla-server/src/argilla_server/bulk/records_bulk.py b/argilla-server/src/argilla_server/bulk/records_bulk.py index 0e3d372be5..6acbc30031 100644 --- a/argilla-server/src/argilla_server/bulk/records_bulk.py +++ b/argilla-server/src/argilla_server/bulk/records_bulk.py @@ -29,6 +29,7 @@ ) from argilla_server.api.schemas.v1.responses import UserResponseCreate from argilla_server.api.schemas.v1.suggestions import SuggestionCreate +from argilla_server.contexts import distribution from argilla_server.contexts.accounts import fetch_users_by_ids_as_dict from argilla_server.contexts.records import ( fetch_records_by_external_ids_as_dict, @@ -67,6 +68,7 @@ async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCr await self._upsert_records_relationships(records, bulk_create.items) await _preload_records_relationships_before_index(self._db, records) + await distribution.update_records_status(self._db, records) await self._search_engine.index_records(dataset, records) await self._db.commit() @@ -207,6 +209,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp await self._upsert_records_relationships(records, bulk_upsert.items) await _preload_records_relationships_before_index(self._db, records) + await distribution.update_records_status(self._db, records) await self._search_engine.index_records(dataset, records) await self._db.commit() @@ -237,6 +240,7 @@ async def _preload_records_relationships_before_index(db: "AsyncSession", record .filter(Record.id.in_([record.id for record in records])) .options( selectinload(Record.responses).selectinload(Response.user), + selectinload(Record.responses_submitted), selectinload(Record.suggestions).selectinload(Suggestion.question), selectinload(Record.vectors), ) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 34468c2b18..1dbf52fc53 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -37,10 +37,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload -from argilla_server.api.schemas.v1.datasets import ( - DatasetCreate, - DatasetProgress, -) +from argilla_server.api.schemas.v1.datasets import DatasetProgress from argilla_server.api.schemas.v1.fields import FieldCreate from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate from argilla_server.api.schemas.v1.records import ( @@ -63,7 +60,7 @@ VectorSettingsCreate, ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema -from argilla_server.contexts import accounts +from argilla_server.contexts import accounts, distribution from argilla_server.enums import DatasetStatus, RecordInclude, UserRole from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( @@ -82,6 +79,7 @@ ) from argilla_server.models.suggestions import SuggestionCreateWithRecordId from argilla_server.search_engine import SearchEngine +from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator from argilla_server.validators.responses import ( ResponseCreateValidator, ResponseUpdateValidator, @@ -122,22 +120,18 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> return result.scalars().all() -async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate): - if await 
Workspace.get(db, dataset_create.workspace_id) is None: - raise UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found") +async def create_dataset(db: AsyncSession, dataset_attrs: dict): + dataset = Dataset( + name=dataset_attrs["name"], + guidelines=dataset_attrs["guidelines"], + allow_extra_metadata=dataset_attrs["allow_extra_metadata"], + distribution=dataset_attrs["distribution"], + workspace_id=dataset_attrs["workspace_id"], + ) - if await Dataset.get_by(db, name=dataset_create.name, workspace_id=dataset_create.workspace_id): - raise NotUniqueError( - f"Dataset with name `{dataset_create.name}` already exists for workspace with id `{dataset_create.workspace_id}`" - ) + await DatasetCreateValidator.validate(db, dataset) - return await Dataset.create( - db, - name=dataset_create.name, - guidelines=dataset_create.guidelines, - allow_extra_metadata=dataset_create.allow_extra_metadata, - workspace_id=dataset_create.workspace_id, - ) + return await dataset.save(db) async def _count_required_fields_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int: @@ -176,6 +170,12 @@ async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset return dataset +async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> Dataset: + await DatasetUpdateValidator.validate(db, dataset, dataset_attrs) + + return await dataset.update(db, **dataset_attrs) + + async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset: async with db.begin_nested(): dataset = await dataset.delete(db, autocommit=False) @@ -186,11 +186,6 @@ async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: return dataset -async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_update: "DatasetUpdate") -> Dataset: - params = dataset_update.dict(exclude_unset=True) - return await dataset.update(db, **params) - - async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field: if dataset.is_ready: raise UnprocessableEntityError("Field cannot be created for a published dataset") @@ -945,6 +940,9 @@ async def create_response( await db.flush([response]) await _touch_dataset_last_activity_at(db, record.dataset) await search_engine.update_record_response(response) + await db.refresh(record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, record) + await search_engine.partial_record_update(record, status=record.status) await db.commit() @@ -968,6 +966,9 @@ async def update_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.update_record_response(response) + await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, response.record) + await search_engine.partial_record_update(response.record, status=response.record.status) await db.commit() @@ -997,6 +998,9 @@ async def upsert_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.update_record_response(response) + await db.refresh(record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, record) + await search_engine.partial_record_update(record, status=record.status) await db.commit() @@ -1006,9 +1010,13 @@ async def upsert_response( async def delete_response(db: AsyncSession, 
search_engine: SearchEngine, response: Response) -> Response: async with db.begin_nested(): response = await response.delete(db, autocommit=False) + await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.delete_record_response(response) + await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, response.record) + await search_engine.partial_record_update(record=response.record, status=response.record.status) await db.commit() diff --git a/argilla-server/src/argilla_server/contexts/distribution.py b/argilla-server/src/argilla_server/contexts/distribution.py new file mode 100644 index 0000000000..92973801ce --- /dev/null +++ b/argilla-server/src/argilla_server/contexts/distribution.py @@ -0,0 +1,42 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus +from argilla_server.models import Record + + +# TODO: Do this with one single update statement for all records if possible to avoid too many queries. +async def update_records_status(db: AsyncSession, records: List[Record]): + for record in records: + await update_record_status(db, record) + + +async def update_record_status(db: AsyncSession, record: Record) -> Record: + if record.dataset.distribution_strategy == DatasetDistributionStrategy.overlap: + return await _update_record_status_with_overlap_strategy(db, record) + + raise NotImplementedError(f"unsupported distribution strategy `{record.dataset.distribution_strategy}`") + + +async def _update_record_status_with_overlap_strategy(db: AsyncSession, record: Record) -> Record: + if len(record.responses_submitted) >= record.dataset.distribution["min_submitted"]: + record.status = RecordStatus.completed + else: + record.status = RecordStatus.pending + + return await record.save(db, autocommit=False) diff --git a/argilla-server/src/argilla_server/enums.py b/argilla-server/src/argilla_server/enums.py index 13b4843280..fcf0b3142f 100644 --- a/argilla-server/src/argilla_server/enums.py +++ b/argilla-server/src/argilla_server/enums.py @@ -43,12 +43,21 @@ class DatasetStatus(str, Enum): ready = "ready" +class DatasetDistributionStrategy(str, Enum): + overlap = "overlap" + + class UserRole(str, Enum): owner = "owner" admin = "admin" annotator = "annotator" +class RecordStatus(str, Enum): + pending = "pending" + completed = "completed" + + class RecordInclude(str, Enum): responses = "responses" suggestions = "suggestions" diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 468b682467..37bd7730c9 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -29,9 +29,12 @@ DatasetStatus, MetadataPropertyType, QuestionType, 
+ RecordStatus, ResponseStatus, SuggestionType, UserRole, + DatasetDistributionStrategy, + RecordStatus, ) from argilla_server.models.base import DatabaseModel from argilla_server.models.metadata_properties import MetadataPropertySettings @@ -180,11 +183,17 @@ def __repr__(self) -> str: ) +RecordStatusEnum = SAEnum(RecordStatus, name="record_status_enum") + + class Record(DatabaseModel): __tablename__ = "records" fields: Mapped[dict] = mapped_column(JSON, default={}) metadata_: Mapped[Optional[dict]] = mapped_column("metadata", MutableDict.as_mutable(JSON), nullable=True) + status: Mapped[RecordStatus] = mapped_column( + RecordStatusEnum, default=RecordStatus.pending, server_default=RecordStatus.pending, index=True + ) external_id: Mapped[Optional[str]] = mapped_column(index=True) dataset_id: Mapped[UUID] = mapped_column(ForeignKey("datasets.id", ondelete="CASCADE"), index=True) @@ -195,6 +204,13 @@ class Record(DatabaseModel): passive_deletes=True, order_by=Response.inserted_at.asc(), ) + responses_submitted: Mapped[List["Response"]] = relationship( + back_populates="record", + cascade="all, delete-orphan", + passive_deletes=True, + primaryjoin=f"and_(Record.id==Response.record_id, Response.status=='{ResponseStatus.submitted}')", + order_by=Response.inserted_at.asc(), + ) suggestions: Mapped[List["Suggestion"]] = relationship( back_populates="record", cascade="all, delete-orphan", @@ -210,17 +226,17 @@ class Record(DatabaseModel): __table_args__ = (UniqueConstraint("external_id", "dataset_id", name="record_external_id_dataset_id_uq"),) + def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]: + for vector in self.vectors: + if vector.vector_settings_id == vector_settings.id: + return vector.value + def __repr__(self): return ( f"Record(id={str(self.id)!r}, external_id={self.external_id!r}, dataset_id={str(self.dataset_id)!r}, " f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" ) - def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]: - for vector in self.vectors: - if vector.vector_settings_id == vector_settings.id: - return vector.value - class Question(DatabaseModel): __tablename__ = "questions" @@ -304,6 +320,7 @@ class Dataset(DatabaseModel): guidelines: Mapped[Optional[str]] = mapped_column(Text) allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true()) status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True) + distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON)) workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True) inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow) updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow) @@ -353,6 +370,10 @@ def is_draft(self): def is_ready(self): return self.status == DatasetStatus.ready + @property + def distribution_strategy(self) -> DatasetDistributionStrategy: + return DatasetDistributionStrategy(self.distribution["strategy"]) + def metadata_property_by_name(self, name: str) -> Union["MetadataProperty", None]: for metadata_property in self.metadata_properties: if metadata_property.name == name: diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index 7c9146cafe..ee1dbcc386 100644 --- 
a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -317,6 +317,10 @@ async def configure_metadata_property(self, dataset: Dataset, metadata_property: async def index_records(self, dataset: Dataset, records: Iterable[Record]): pass + @abstractmethod + async def partial_record_update(self, record: Record, **update): + pass + @abstractmethod async def delete_records(self, dataset: Dataset, records: Iterable[Record]): pass diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 2030b59ae5..b328224f19 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -346,6 +346,10 @@ async def index_records(self, dataset: Dataset, records: Iterable[Record]): await self._bulk_op_request(bulk_actions) + async def partial_record_update(self, record: Record, **update): + index_name = await self._get_dataset_index(record.dataset) + await self._update_document_request(index_name=index_name, id=str(record.id), body={"doc": update}) + async def delete_records(self, dataset: Dataset, records: Iterable[Record]): index_name = await self._get_dataset_index(dataset) @@ -552,6 +556,7 @@ def _map_record_to_es_document(self, record: Record) -> Dict[str, Any]: document = { "id": str(record.id), "fields": record.fields, + "status": record.status, "inserted_at": record.inserted_at, "updated_at": record.updated_at, } @@ -712,6 +717,7 @@ def _configure_index_mappings(self, dataset: Dataset) -> dict: "properties": { # See https://www.elastic.co/guide/en/elasticsearch/reference/current/explicit-mapping.html "id": {"type": "keyword"}, + "status": {"type": "keyword"}, RecordSortField.inserted_at.value: {"type": "date_nanos"}, RecordSortField.updated_at.value: {"type": "date_nanos"}, "responses": {"dynamic": True, "type": "object"}, diff --git a/argilla-server/src/argilla_server/validators/datasets.py b/argilla-server/src/argilla_server/validators/datasets.py new file mode 100644 index 0000000000..aae2a5fc83 --- /dev/null +++ b/argilla-server/src/argilla_server/validators/datasets.py @@ -0,0 +1,48 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError +from argilla_server.models import Dataset, Workspace + + +class DatasetCreateValidator: + @classmethod + async def validate(cls, db: AsyncSession, dataset: Dataset) -> None: + await cls._validate_workspace_is_present(db, dataset.workspace_id) + await cls._validate_name_is_not_duplicated(db, dataset.name, dataset.workspace_id) + + @classmethod + async def _validate_workspace_is_present(cls, db: AsyncSession, workspace_id: UUID) -> None: + if await Workspace.get(db, workspace_id) is None: + raise UnprocessableEntityError(f"Workspace with id `{workspace_id}` not found") + + @classmethod + async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, workspace_id: UUID) -> None: + if await Dataset.get_by(db, name=name, workspace_id=workspace_id): + raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`") + + +class DatasetUpdateValidator: + @classmethod + async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None: + cls._validate_distribution(dataset, dataset_attrs) + + @classmethod + def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: + if dataset.is_ready and dataset_attrs.get("distribution") is not None: + raise UnprocessableEntityError(f"Distribution settings cannot be modified for a published dataset") diff --git a/argilla-server/tests/factories.py b/argilla-server/tests/factories.py index 5c77b9a0f5..c429fed9af 100644 --- a/argilla-server/tests/factories.py +++ b/argilla-server/tests/factories.py @@ -16,7 +16,7 @@ import random import factory -from argilla_server.enums import FieldType, MetadataPropertyType, OptionsOrder +from argilla_server.enums import DatasetDistributionStrategy, FieldType, MetadataPropertyType, OptionsOrder from argilla_server.models import ( Dataset, Field, @@ -203,6 +203,7 @@ class Meta: model = Dataset name = factory.Sequence(lambda n: f"dataset-{n}") + distribution = {"strategy": DatasetDistributionStrategy.overlap, "min_submitted": 1} workspace = factory.SubFactory(WorkspaceFactory) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py index d7e95520d5..3d1f0bf6da 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py @@ -15,7 +15,7 @@ from uuid import UUID import pytest -from argilla_server.enums import DatasetStatus +from argilla_server.enums import DatasetStatus, RecordStatus from argilla_server.models import Dataset, Record from httpx import AsyncClient from sqlalchemy import func, select @@ -87,6 +87,7 @@ async def test_create_dataset_records_bulk( "items": [ { "id": str(record.id), + "status": RecordStatus.pending, "dataset_id": str(dataset.id), "external_id": record.external_id, "fields": record.fields, diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py new file mode 100644 index 0000000000..4261145d0c --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py @@ -0,0 +1,139 @@ +# Copyright 2021-present, the 
Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from argilla_server.models import Dataset +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import WorkspaceFactory + + +@pytest.mark.asyncio +class TestCreateDataset: + def url(self) -> str: + return "/api/v1/datasets" + + async def test_create_dataset_with_default_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution_using_invalid_min_submitted_value( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 0, + }, + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 + + async def 
test_create_dataset_with_invalid_distribution_strategy( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "distribution": { + "strategy": "invalid_strategy", + }, + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py index e70072d814..8d4981e828 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py @@ -16,7 +16,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import UserRole +from argilla_server.enums import UserRole, RecordStatus from argilla_server.search_engine import SearchEngine, SearchResponseItem, SearchResponses from httpx import AsyncClient @@ -71,6 +71,7 @@ async def test_search_with_filtered_metadata( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": record.metadata_, "external_id": record.external_id, @@ -122,6 +123,7 @@ async def test_search_with_filtered_metadata_as_annotator( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": {"annotator_meta": "value"}, "external_id": record.external_id, @@ -173,6 +175,7 @@ async def test_search_with_filtered_metadata_as_admin( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": {"admin_meta": "value", "annotator_meta": "value", "extra": "value"}, "external_id": record.external_id, diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 3d22527c3b..73077c4381 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -17,7 +17,7 @@ import pytest from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_LE from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, SortOrder +from argilla_server.enums import RecordInclude, SortOrder, RecordStatus from argilla_server.search_engine import ( AndFilter, Order, @@ -118,6 +118,7 @@ async def test_with_include_responses( { "record": { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": { "sentiment": "neutral", "text": "This is a text", @@ -153,6 +154,7 @@ async def test_with_include_responses( { "record": { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": { "sentiment": "neutral", "text": "This is a text", diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py new file mode 100644 index 0000000000..cdb9b06ea2 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -0,0 +1,178 @@ +# Copyright 2021-present, the Recognai 
S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +import pytest +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from httpx import AsyncClient + +from tests.factories import DatasetFactory + + +@pytest.mark.asyncio +class TestUpdateDataset: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}" + + async def test_update_dataset_distribution(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + } + + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + } + + async def test_update_dataset_without_distribution(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"name": "Dataset updated name"}, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + assert dataset.name == "Dataset updated name" + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_without_distribution_for_published_dataset( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"name": "Dataset updated name"}, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + assert dataset.name == "Dataset updated name" + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_for_published_dataset( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 422 + assert response.json() == {"detail": "Distribution settings cannot be modified for a published dataset"} + + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def 
test_update_dataset_distribution_with_invalid_strategy( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": "invalid_strategy", + }, + }, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_with_invalid_min_submitted_value( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 0, + }, + }, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_as_none(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"distribution": None}, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py new file mode 100644 index 0000000000..1aae133535 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py @@ -0,0 +1,145 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from uuid import UUID +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.models import User, Record +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus, DatasetStatus + +from tests.factories import AnnotatorFactory, DatasetFactory, TextFieldFactory, TextQuestionFactory + + +@pytest.mark.asyncio +class TestCreateDatasetRecordsBulk: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/records/bulk" + + async def test_create_dataset_records_bulk_updates_records_status( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + status=DatasetStatus.ready, + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + user = await AnnotatorFactory.create(workspaces=[dataset.workspace]) + + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextFieldFactory.create(name="response", dataset=dataset) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + }, + ], + }, + ) + + assert response.status_code == 201 + + response_items = response.json()["items"] + assert response_items[0]["status"] == RecordStatus.completed + assert response_items[1]["status"] == RecordStatus.pending + assert response_items[2]["status"] == RecordStatus.pending + assert response_items[3]["status"] == RecordStatus.pending + + assert (await Record.get(db, UUID(response_items[0]["id"]))).status == RecordStatus.completed + assert (await Record.get(db, UUID(response_items[1]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[2]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[3]["id"]))).status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py 
b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py index 98b3a864b9..ce433d036d 100644 --- a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py @@ -16,13 +16,15 @@ from uuid import UUID import pytest -from argilla_server.enums import ResponseStatusFilter -from argilla_server.models import Response, User + from httpx import AsyncClient from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory +from argilla_server.enums import ResponseStatus, RecordStatus, DatasetDistributionStrategy +from argilla_server.models import Response, User + +from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory, TextQuestionFactory @pytest.mark.asyncio @@ -52,7 +54,7 @@ async def test_create_record_response_for_span_question( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -72,7 +74,7 @@ async def test_create_record_response_for_span_question( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -101,7 +103,7 @@ async def test_create_record_response_for_span_question_with_additional_value_at ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -121,7 +123,7 @@ async def test_create_record_response_for_span_question_with_additional_value_at ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -146,7 +148,7 @@ async def test_create_record_response_for_span_question_with_empty_value( "value": [], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -162,7 +164,7 @@ async def test_create_record_response_for_span_question_with_empty_value( "value": [], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -189,7 +191,7 @@ async def test_create_record_response_for_span_question_with_record_not_providin ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -219,7 +221,7 @@ async def test_create_record_response_for_span_question_with_invalid_value( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -244,7 +246,7 @@ async def test_create_record_response_for_span_question_with_start_greater_than_ "value": [{"label": "label-a", "start": 5, "end": 6}], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -273,7 +275,7 @@ async def test_create_record_response_for_span_question_with_end_greater_than_ex "value": [{"label": "label-a", "start": 4, "end": 6}], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -304,7 +306,7 @@ async def test_create_record_response_for_span_question_with_invalid_start( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -331,7 +333,7 @@ async def 
test_create_record_response_for_span_question_with_invalid_end( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -358,7 +360,7 @@ async def test_create_record_response_for_span_question_with_equal_start_and_end ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -385,7 +387,7 @@ async def test_create_record_response_for_span_question_with_end_smaller_than_st ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -412,7 +414,7 @@ async def test_create_record_response_for_span_question_with_non_existent_label( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -446,7 +448,7 @@ async def test_create_record_response_for_span_question_with_overlapped_values( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -454,3 +456,63 @@ async def test_create_record_response_for_span_question_with_overlapped_values( assert response.json() == {"detail": "overlapping values found between spans at index idx=0 and idx=2"} assert (await db.execute(select(func.count(Response.id)))).scalar() == 0 + + async def test_create_record_response_updates_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + assert record.status == RecordStatus.completed + + async def test_create_record_response_does_not_updates_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + assert record.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py new file mode 100644 index 0000000000..82b035a58a --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py @@ -0,0 +1,153 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient + +from argilla_server.models import User +from argilla_server.enums import DatasetDistributionStrategy, ResponseStatus, DatasetStatus, RecordStatus + +from tests.factories import DatasetFactory, RecordFactory, TextQuestionFactory, ResponseFactory, AnnotatorFactory + + +@pytest.mark.asyncio +class TestUpsertDatasetRecordsBulk: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/records/bulk" + + async def test_upsert_dataset_records_bulk_updates_records_status( + self, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + status=DatasetStatus.ready, + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + user = await AnnotatorFactory.create(workspaces=[dataset.workspace]) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record_a = await RecordFactory.create(dataset=dataset) + assert record_a.status == RecordStatus.pending + + await ResponseFactory.create( + user=owner, + record=record_a, + status=ResponseStatus.submitted, + values={ + "text-question": { + "value": "text question response", + }, + }, + ) + + record_b = await RecordFactory.create(dataset=dataset) + assert record_b.status == RecordStatus.pending + + record_c = await RecordFactory.create(dataset=dataset) + assert record_c.status == RecordStatus.pending + + record_d = await RecordFactory.create(dataset=dataset) + assert record_d.status == RecordStatus.pending + + response = await async_client.put( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "id": str(record_a.id), + "responses": [ + { + "user_id": str(user.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "id": str(record_b.id), + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "id": str(record_c.id), + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "id": str(record_d.id), + "responses": [], + }, + ], + }, + ) + + assert response.status_code == 200 + + response_items = response.json()["items"] + assert response_items[0]["status"] == RecordStatus.completed + assert response_items[1]["status"] == RecordStatus.pending + assert response_items[2]["status"] == RecordStatus.pending + assert response_items[3]["status"] == RecordStatus.pending + + assert record_a.status == RecordStatus.completed + assert record_b.status == RecordStatus.pending +
assert record_c.status == RecordStatus.pending + assert record_d.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py index 009cec7d2e..07b4bf0199 100644 --- a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py @@ -18,7 +18,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus +from argilla_server.enums import ResponseStatus, RecordStatus from argilla_server.models import Response, User from argilla_server.search_engine import SearchEngine from argilla_server.use_cases.responses.upsert_responses_in_bulk import UpsertResponsesInBulkUseCase @@ -111,7 +111,7 @@ async def test_multiple_responses( "item": { "id": str(response_to_create_id), "values": {"prompt-quality": {"value": 5}}, - "status": ResponseStatus.submitted.value, + "status": ResponseStatus.submitted, "record_id": str(records[0].id), "user_id": str(annotator.id), "inserted_at": datetime.fromisoformat(resp_json["items"][0]["item"]["inserted_at"]).isoformat(), @@ -123,7 +123,7 @@ async def test_multiple_responses( "item": { "id": str(response_to_update.id), "values": {"prompt-quality": {"value": 10}}, - "status": ResponseStatus.submitted.value, + "status": ResponseStatus.submitted, "record_id": str(records[1].id), "user_id": str(annotator.id), "inserted_at": datetime.fromisoformat(resp_json["items"][1]["item"]["inserted_at"]).isoformat(), @@ -146,6 +146,10 @@ async def test_multiple_responses( ], } + assert records[0].status == RecordStatus.completed + assert records[1].status == RecordStatus.completed + assert records[2].status == RecordStatus.pending + assert (await db.execute(select(func.count(Response.id)))).scalar() == 2 response_to_create = (await db.execute(select(Response).filter_by(id=response_to_create_id))).scalar_one() diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py new file mode 100644 index 0000000000..6b9d4ec749 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py @@ -0,0 +1,66 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from uuid import UUID + +import pytest + +from httpx import AsyncClient + +from argilla_server.models import User +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus + +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, TextQuestionFactory + + +@pytest.mark.asyncio +class TestDeleteResponse: + def url(self, response_id: UUID) -> str: + return f"/api/v1/responses/{response_id}" + + async def test_delete_response_updates_record_status_to_pending( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + ) + + record = await RecordFactory.create(status=RecordStatus.completed, dataset=dataset) + response = await ResponseFactory.create(record=record) + + resp = await async_client.delete(self.url(response.id), headers=owner_auth_header) + + assert resp.status_code == 200 + assert record.status == RecordStatus.pending + + async def test_delete_response_does_not_updates_record_status_to_pending( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + } + ) + + record = await RecordFactory.create(status=RecordStatus.completed, dataset=dataset) + responses = await ResponseFactory.create_batch(3, record=record) + + resp = await async_client.delete(self.url(responses[0].id), headers=owner_auth_header) + + assert resp.status_code == 200 + assert record.status == RecordStatus.completed diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py index f5ffab7b31..d5097f8c7b 100644 --- a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py @@ -16,13 +16,15 @@ from uuid import UUID import pytest -from argilla_server.enums import ResponseStatus -from argilla_server.models import Response, User from httpx import AsyncClient + from sqlalchemy import select from sqlalchemy.ext.asyncio.session import AsyncSession -from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, SpanQuestionFactory +from argilla_server.enums import ResponseStatus, DatasetDistributionStrategy, RecordStatus +from argilla_server.models import Response, User + +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, SpanQuestionFactory, TextQuestionFactory @pytest.mark.asyncio @@ -560,3 +562,66 @@ async def test_update_response_for_span_question_with_non_existent_label( } assert (await db.execute(select(Response).filter_by(id=response.id))).scalar_one().values == response_values + + async def test_update_response_updates_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + response = await ResponseFactory.create(record=record, status=ResponseStatus.draft) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question 
updated response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert resp.status_code == 200 + assert record.status == RecordStatus.completed + + async def test_update_response_updates_record_status_to_pending( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset, status=RecordStatus.completed) + response = await ResponseFactory.create( + values={ + "text-question": { + "value": "text question response", + }, + }, + record=record, + status=ResponseStatus.submitted, + ) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={"status": ResponseStatus.draft}, + ) + + assert resp.status_code == 200 + assert record.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 650e9f3808..e0c9fe4d5e 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -34,11 +34,13 @@ ) from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.enums import ( + DatasetDistributionStrategy, DatasetStatus, OptionsOrder, RecordInclude, ResponseStatusFilter, SimilarityOrder, + RecordStatus, ) from argilla_server.models import ( Dataset, @@ -116,6 +118,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_a.workspace_id), "last_activity_at": dataset_a.last_activity_at.isoformat(), "inserted_at": dataset_a.inserted_at.isoformat(), @@ -127,6 +133,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": "guidelines", "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_b.workspace_id), "last_activity_at": dataset_b.last_activity_at.isoformat(), "inserted_at": dataset_b.inserted_at.isoformat(), @@ -138,6 +148,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_c.workspace_id), "last_activity_at": dataset_c.last_activity_at.isoformat(), "inserted_at": dataset_c.inserted_at.isoformat(), @@ -653,8 +667,6 @@ async def test_list_dataset_vectors_settings_without_authentication(self, async_ assert response.status_code == 401 - # Helper function to create records with responses - async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create(name="dataset") @@ -667,6 +679,10 @@ async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": 
dataset.last_activity_at.isoformat(), "inserted_at": dataset.inserted_at.isoformat(), @@ -839,13 +855,16 @@ async def test_create_dataset(self, async_client: "AsyncClient", db: "AsyncSessi await db.refresh(workspace) response_body = response.json() - assert (await db.execute(select(func.count(Dataset.id)))).scalar() == 1 assert response_body == { "id": str(UUID(response_body["id"])), "name": "name", "guidelines": "guidelines", "allow_extra_metadata": False, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(workspace.id), "last_activity_at": datetime.fromisoformat(response_body["last_activity_at"]).isoformat(), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), @@ -3644,6 +3663,7 @@ async def test_search_current_user_dataset_records( { "record": { "id": str(records[0].id), + "status": RecordStatus.pending, "fields": {"input": "input_a", "output": "output_a"}, "metadata": None, "external_id": records[0].external_id, @@ -3656,6 +3676,7 @@ async def test_search_current_user_dataset_records( { "record": { "id": str(records[1].id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"unit": "test"}, "external_id": records[1].external_id, @@ -3997,6 +4018,7 @@ async def test_search_current_user_dataset_records_with_include( { "record": { "id": str(records[0].id), + "status": RecordStatus.pending, "fields": { "input": "input_a", "output": "output_a", @@ -4012,6 +4034,7 @@ async def test_search_current_user_dataset_records_with_include( { "record": { "id": str(records[1].id), + "status": RecordStatus.pending, "fields": { "input": "input_b", "output": "output_b", @@ -4151,6 +4174,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( { "record": { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -4167,6 +4191,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( { "record": { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, @@ -4182,6 +4207,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( { "record": { "id": str(record_c.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, @@ -4245,6 +4271,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors { "record": { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -4261,6 +4288,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors { "record": { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, @@ -4276,6 +4304,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors { "record": { "id": str(record_c.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, @@ -4752,6 +4781,10 @@ async def test_update_dataset(self, async_client: 
"AsyncClient", db: "AsyncSessi "guidelines": guidelines, "allow_extra_metadata": allow_extra_metadata, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": dataset.last_activity_at.isoformat(), "inserted_at": dataset.inserted_at.isoformat(), diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index f088cfcda9..8f78940df3 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -18,7 +18,7 @@ import pytest from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole +from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole, RecordStatus from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace from argilla_server.search_engine import ( FloatMetadataFilter, @@ -821,6 +821,7 @@ async def test_list_current_user_dataset_records( "items": [ { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": {"input": "input_a", "output": "output_a"}, "metadata": None, "dataset_id": str(dataset.id), @@ -830,6 +831,7 @@ async def test_list_current_user_dataset_records( }, { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"unit": "test"}, "dataset_id": str(dataset.id), @@ -839,6 +841,7 @@ async def test_list_current_user_dataset_records( }, { "id": str(record_c.id), + "status": RecordStatus.pending, "fields": {"input": "input_c", "output": "output_c"}, "metadata": None, "dataset_id": str(dataset.id), @@ -898,6 +901,7 @@ async def test_list_current_user_dataset_records_with_filtered_metadata_as_annot "items": [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"key1": "value1"}, "dataset_id": str(dataset.id), diff --git a/argilla-server/tests/unit/api/handlers/v1/test_records.py b/argilla-server/tests/unit/api/handlers/v1/test_records.py index ed7d9f8cc2..3c361b1666 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_records.py @@ -19,7 +19,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus +from argilla_server.enums import RecordStatus, ResponseStatus from argilla_server.models import Dataset, Record, Response, Suggestion, User, UserRole from argilla_server.search_engine import SearchEngine from sqlalchemy import func, select @@ -92,6 +92,7 @@ async def test_get_record(self, async_client: "AsyncClient", role: UserRole): assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -188,6 +189,7 @@ async def test_update_record(self, async_client: "AsyncClient", mock_search_engi assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": 
"neutral"}, "metadata": { "terms-metadata-property": "c", @@ -228,6 +230,7 @@ async def test_update_record(self, async_client: "AsyncClient", mock_search_engi "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } + mock_search_engine.index_records.assert_called_once_with(dataset, [record]) async def test_update_record_with_null_metadata( @@ -251,6 +254,7 @@ async def test_update_record_with_null_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -278,6 +282,7 @@ async def test_update_record_with_no_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -310,6 +315,7 @@ async def test_update_record_with_list_terms_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": { "terms-metadata-property": ["a", "b", "c"], @@ -339,6 +345,7 @@ async def test_update_record_with_no_suggestions( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -1413,6 +1420,7 @@ async def test_delete_record( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": None, "external_id": record.external_id, diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index ecba3232a6..c4376ca686 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -16,7 +16,7 @@ import pytest import pytest_asyncio -from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder +from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( FloatMetadataFilter, @@ -263,6 +263,7 @@ async def refresh_records(records: List[Record]): for record in records: await record.awaitable_attrs.suggestions await record.awaitable_attrs.responses + await record.awaitable_attrs.responses_submitted await record.awaitable_attrs.vectors @@ -314,6 +315,7 @@ async def test_create_index_for_dataset( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -356,6 +358,7 @@ async def test_create_index_for_dataset_with_fields( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -428,6 +431,7 @@ async def test_create_index_for_dataset_with_metadata_properties( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, 
"inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -475,6 +479,7 @@ async def test_create_index_for_dataset_with_questions( "dynamic": "strict", "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -879,6 +884,7 @@ async def test_index_records(self, search_engine: BaseElasticAndOpenSearchEngine assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), @@ -937,6 +943,7 @@ async def test_index_records_with_suggestions( assert es_docs == [ { "id": str(records[0].id), + "status": RecordStatus.pending, "fields": records[0].fields, "inserted_at": records[0].inserted_at.isoformat(), "updated_at": records[0].updated_at.isoformat(), @@ -944,6 +951,7 @@ async def test_index_records_with_suggestions( }, { "id": str(records[1].id), + "status": RecordStatus.pending, "fields": records[1].fields, "inserted_at": records[1].inserted_at.isoformat(), "updated_at": records[1].updated_at.isoformat(), @@ -978,6 +986,7 @@ async def test_index_records_with_metadata( assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), @@ -1017,6 +1026,7 @@ async def test_index_records_with_vectors( assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), diff --git a/argilla/src/argilla/_models/_search.py b/argilla/src/argilla/_models/_search.py index f62dbff0b7..3c256805a0 100644 --- a/argilla/src/argilla/_models/_search.py +++ b/argilla/src/argilla/_models/_search.py @@ -17,6 +17,11 @@ from pydantic import BaseModel, Field +class RecordFilterScopeModel(BaseModel): + entity: Literal["record"] = "record" + property: Literal["status"] = "status" + + class ResponseFilterScopeModel(BaseModel): """Filter scope for filtering on a response entity.""" @@ -42,6 +47,7 @@ class MetadataFilterScopeModel(BaseModel): ScopeModel = Annotated[ Union[ + RecordFilterScopeModel, ResponseFilterScopeModel, SuggestionFilterScopeModel, MetadataFilterScopeModel, diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index adc56b5750..6ccdcee33a 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -26,6 +26,7 @@ FilterModel, AndFilterModel, QueryModel, + RecordFilterScopeModel, ) @@ -54,8 +55,9 @@ def model(self) -> FilterModel: @staticmethod def _extract_filter_scope(field: str) -> ScopeModel: field = field.strip() - if field == "status": + return RecordFilterScopeModel(property="status") + elif field == "responses.status": return ResponseFilterScopeModel(property="status") elif "metadata" in field: _, md_property = field.split(".") From f084ab7026a13475c467aef6bb3fe430eb6c0f21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dami=C3=A1n=20Pumar?= Date: Thu, 4 Jul 2024 09:36:19 +0200 Subject: [PATCH 02/34] =?UTF-8?q?=E2=9C=A8=20Remove=20unused=20method?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../repositories/RecordRepository.ts | 30 ------------------- 1 file changed, 30 deletions(-) 
diff --git a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts index 40ce2645eb..871282d9e7 100644 --- a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts @@ -43,7 +43,6 @@ export class RecordRepository { getRecords(criteria: RecordCriteria): Promise { return this.getRecordsByAdvanceSearch(criteria); - // return this.getRecordsByDatasetId(criteria); } async getRecord(recordId: string): Promise { @@ -186,35 +185,6 @@ export class RecordRepository { } } - private async getRecordsByDatasetId( - criteria: RecordCriteria - ): Promise { - const { datasetId, status, page } = criteria; - const { from, many } = page.server; - try { - const url = `/v1/me/datasets/${datasetId}/records`; - - const params = this.createParams(from, many, status); - - const { data } = await this.axios.get>( - url, - { - params, - } - ); - const { items: records, total } = data; - - return { - records, - total, - }; - } catch (err) { - throw { - response: RECORD_API_ERRORS.ERROR_FETCHING_RECORDS, - }; - } - } - private async getRecordsByAdvanceSearch( criteria: RecordCriteria ): Promise { From 6df52560973eda8f1dd7b67fa282ed5db18ea5e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 4 Jul 2024 11:09:56 +0200 Subject: [PATCH 03/34] feat: improve Records `responses_submitted` relationship to be view only (#5148) # Description Add changes to the `responses_submitted` relationship to avoid problems with the existing `responses` relationship and to avoid a warning message that SQLAlchemy was reporting. Refs #5000 **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** - [x] The warning is not showing anymore. - [x] Tests are passing.
**Checklist** - I added relevant documentation - follows the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/src/argilla_server/models/database.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 37bd7730c9..3230916362 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -206,8 +206,7 @@ class Record(DatabaseModel): ) responses_submitted: Mapped[List["Response"]] = relationship( back_populates="record", - cascade="all, delete-orphan", - passive_deletes=True, + viewonly=True, primaryjoin=f"and_(Record.id==Response.record_id, Response.status=='{ResponseStatus.submitted}')", order_by=Response.inserted_at.asc(), ) From cf3408c7988285b3083edd28fa9c7936370283ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 4 Jul 2024 11:56:13 +0200 Subject: [PATCH 04/34] feat: change metrics to support new distribution task logic (#5140) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR adds changes to the endpoints to get the dataset progress and current user metrics in the following way: ## `GET /datasets/:dataset_id/progress` I have changed the endpoint to support the new business logic behind the distribution task, responding with only `completed` and `pending` types of records and using `total` as the sum of the two. Old response without distribution task: ```json { "total": 8, "submitted": 2, "discarded": 2, "conflicting": 1, "pending": 3 } ``` New response with the changes from this PR supporting distribution task: * The `completed` attribute will have the count of all the records with status as `completed` for the dataset. * The `pending` attribute will have the count of all the records with status as `pending` for the dataset. * The `total` attribute will have the sum of the `completed` and `pending` attributes. ```json { "total": 5, "completed": 2, "pending": 3 } ``` @damianpumar some changes are required on the frontend to support this new endpoint structure. ## `GET /me/datasets/:dataset_id/metrics` Old response without distribution task: ```json { "records": { "count": 7 }, "responses": { "count": 4, "submitted": 1, "discarded": 2, "draft": 1 } } ``` New response with the changes from this PR supporting distribution task: * `records` section has been eliminated because it is not necessary anymore. * `responses` `count` section has been renamed to `total`. * `pending` section has been added to the `responses` section. ```json { "responses": { "total": 7, "submitted": 1, "discarded": 2, "draft": 1, "pending": 3 } } ``` The logic behind these attributes is the following: * `total` is the sum of `submitted`, `discarded`, `draft` and `pending` attribute values. * `submitted` is the count of all responses belonging to the current user in the specified dataset with `submitted` status. * `discarded` is the count of all responses belonging to the current user in the specified dataset with `discarded` status.
* `draft` is the count of all responses belonging to the current user in the specified dataset with `draft` status. * `pending` is the count of all records with `pending` status for the dataset that has not responses belonging to the current user. @damianpumar some changes are required on the frontend to support this new endpoint structure as well. Closes #5139 **Type of change** - Breaking change (fix or feature that would cause existing functionality to not work as expected) **How Has This Been Tested** - [x] Modifying existent tests. - [x] Running test suite with SQLite and PostgreSQL. **Checklist** - I added relevant documentation - follows the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paco Aranda Co-authored-by: Damián Pumar --- .../useDatasetProgressViewModel.ts | 22 +--- argilla-frontend/translation/de.js | 5 + argilla-frontend/translation/en.js | 4 +- .../domain/entities/dataset/Metrics.test.ts | 28 ++-- .../v1/domain/entities/dataset/Metrics.ts | 16 +-- .../v1/domain/entities/dataset/Progress.ts | 4 +- .../repositories/DatasetRepository.ts | 8 +- .../repositories/MetricsRepository.ts | 12 +- .../v1/infrastructure/types/dataset.ts | 4 +- argilla-server/CHANGELOG.md | 2 + .../api/handlers/v1/datasets/datasets.py | 18 +-- .../argilla_server/api/schemas/v1/datasets.py | 12 +- .../src/argilla_server/contexts/datasets.py | 121 +++++++++++------- .../v1/datasets/test_get_dataset_progress.py | 80 ++---------- .../v1/datasets/test_update_dataset.py | 3 +- .../unit/api/handlers/v1/test_datasets.py | 50 ++++++-- 16 files changed, 168 insertions(+), 221 deletions(-) diff --git a/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts b/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts index f2b1ef6afc..149b45ac10 100644 --- a/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts +++ b/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts @@ -22,25 +22,11 @@ export const useDatasetProgressViewModel = ({ progressRanges.value = [ { - id: "submitted", - name: t("datasets.submitted"), + id: "completed", + name: t("datasets.completed"), color: "#0508D9", - value: progress.value.submitted, - tooltip: `${progress.value.submitted}/${progress.value.total}`, - }, - { - id: "conflicting", - name: t("datasets.conflicting"), - color: "#8893c0", - value: progress.value.conflicting, - tooltip: `${progress.value.conflicting}/${progress.value.total}`, - }, - { - id: "discarded", - name: t("datasets.discarded"), - color: "#b7b7b7", - value: progress.value.discarded, - tooltip: `${progress.value.discarded}/${progress.value.total}`, + value: progress.value.completed, + tooltip: `${progress.value.completed}/${progress.value.total}`, }, { id: "pending", diff --git a/argilla-frontend/translation/de.js b/argilla-frontend/translation/de.js index 099bd2e233..8d17eb4ac9 100644 --- a/argilla-frontend/translation/de.js +++ b/argilla-frontend/translation/de.js @@ -36,6 +36,11 @@ export default { datasetSettings: "einstellungen", userSettings: "meine einstellungen", 
}, + datasets: { + left: "übrig", + completed: "Vollendet", + pending: "Ausstehend", + }, recordStatus: { pending: "Ausstehend", draft: "Entwurf", diff --git a/argilla-frontend/translation/en.js b/argilla-frontend/translation/en.js index af12e6df17..6ceac06d00 100644 --- a/argilla-frontend/translation/en.js +++ b/argilla-frontend/translation/en.js @@ -42,9 +42,7 @@ export default { }, datasets: { left: "left", - submitted: "Submitted", - conflicting: "Conflicting", - discarded: "Discarded", + completed: "Completed", pending: "Pending", }, recordStatus: { diff --git a/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts b/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts index 792450fe5b..322f480007 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts @@ -20,67 +20,67 @@ describe("Metrics", () => { describe("total", () => { it("should return the total number of records", () => { - const metrics = new Metrics(1, 0, 0, 0, 0); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.total; - expect(result).toEqual(1); + expect(result).toEqual(15); }); }); describe("responded", () => { it("should return the number of responded records", () => { - const metrics = new Metrics(5, 5, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.responded; - expect(result).toEqual(5); + expect(result).toEqual(10); }); }); describe("pending", () => { it("should return the number of pending records", () => { - const metrics = new Metrics(5, 4, 3, 1, 0); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.pending; - expect(result).toEqual(1); + expect(result).toEqual(5); }); }); describe("progress", () => { it("should return the progress of responded records", () => { - const metrics = new Metrics(5, 4, 3, 1, 0); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.progress; - expect(result).toEqual(0.8); + expect(result).toEqual(0.6666666666666666); }); }); describe("percentage", () => { it("should return the percentage of draft records", () => { - const metrics = new Metrics(5, 4, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.percentage.draft; - expect(result).toEqual(20); + expect(result).toEqual(6.666666666666667); }); it("should return the percentage of submitted records", () => { - const metrics = new Metrics(5, 4, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.percentage.submitted; - expect(result).toEqual(60); + expect(result).toEqual(26.666666666666668); }); it("should return the percentage of discarded records", () => { - const metrics = new Metrics(5, 4, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.percentage.discarded; - expect(result).toEqual(20); + expect(result).toEqual(33.333333333333336); }); }); }); diff --git a/argilla-frontend/v1/domain/entities/dataset/Metrics.ts b/argilla-frontend/v1/domain/entities/dataset/Metrics.ts index 31c80d6e08..ec1245e010 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Metrics.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Metrics.ts @@ -7,11 +7,11 @@ export class Metrics { }; constructor( - private readonly records: number, - public readonly responses: number, + public readonly total: number, public readonly submitted: number, public readonly discarded: number, - public readonly draft: number + public readonly draft: number, + public readonly pending: number ) { 
this.percentage = { pending: (this.pending * 100) / this.total, @@ -22,21 +22,13 @@ export class Metrics { } get hasMetrics() { - return this.records > 0; - } - - get total() { - return this.records; + return this.total > 0; } get responded() { return this.submitted + this.discarded + this.draft; } - get pending() { - return this.total - this.responded; - } - get progress() { return this.responded / this.total; } diff --git a/argilla-frontend/v1/domain/entities/dataset/Progress.ts b/argilla-frontend/v1/domain/entities/dataset/Progress.ts index 64c137f672..d996580c3d 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Progress.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Progress.ts @@ -1,9 +1,7 @@ export class Progress { constructor( public readonly total: number, - public readonly submitted: number, - public readonly discarded: number, - public readonly conflicting: number, + public readonly completed: number, public readonly pending: number ) {} } diff --git a/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts b/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts index 875935d9f0..fb82353fb8 100644 --- a/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts @@ -107,13 +107,7 @@ export class DatasetRepository implements IDatasetRepository { largeCache() ); - return new Progress( - data.total, - data.submitted, - data.discarded, - data.conflicting, - data.pending - ); + return new Progress(data.total, data.completed, data.pending); } catch (err) { throw { response: DATASET_API_ERRORS.ERROR_DELETING_DATASET, diff --git a/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts b/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts index 7cff90f7f9..2ddc434ef7 100644 --- a/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts @@ -3,14 +3,12 @@ import { largeCache } from "./AxiosCache"; import { Metrics } from "~/v1/domain/entities/dataset/Metrics"; interface BackendMetrics { - records: { - count: number; - }; responses: { - count: number; + total: number; submitted: number; discarded: number; draft: number; + pending: number; }; } @@ -25,11 +23,11 @@ export class MetricsRepository { ); return new Metrics( - data.records.count, - data.responses.count, + data.responses.total, data.responses.submitted, data.responses.discarded, - data.responses.draft + data.responses.draft, + data.responses.pending ); } catch { /* lint:disable:no-empty */ diff --git a/argilla-frontend/v1/infrastructure/types/dataset.ts b/argilla-frontend/v1/infrastructure/types/dataset.ts index 7b160fcbbf..e270b5495f 100644 --- a/argilla-frontend/v1/infrastructure/types/dataset.ts +++ b/argilla-frontend/v1/infrastructure/types/dataset.ts @@ -16,8 +16,6 @@ export interface BackendDatasetFeedbackTaskResponse { export interface BackendProgress { total: number; - submitted: number; - discarded: number; - conflicting: number; + completed: number; pending: number; } diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 2883fc9c6e..e466dbbded 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -24,6 +24,8 @@ These are the section headers that we use: ### Changed - Change `responses` table to delete rows on cascade when a user is deleted. 
([#5126](https://github.com/argilla-io/argilla/pull/5126)) +- [breaking] Change `GET /datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) +- [breaking] Change `GET /me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) ### Fixed diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py index 0590b41bb4..85bf7962c8 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py @@ -147,23 +147,7 @@ async def get_current_user_dataset_metrics( await authorize(current_user, DatasetPolicy.get(dataset)) - return { - "records": { - "count": await datasets.count_records_by_dataset_id(db, dataset_id), - }, - "responses": { - "count": await datasets.count_responses_by_dataset_id_and_user_id(db, dataset_id, current_user.id), - "submitted": await datasets.count_responses_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, ResponseStatus.submitted - ), - "discarded": await datasets.count_responses_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, ResponseStatus.discarded - ), - "draft": await datasets.count_responses_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, ResponseStatus.draft - ), - }, - } + return await datasets.get_user_dataset_metrics(db, current_user.id, dataset.id) @router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress) diff --git a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py index 1e1b69d836..dd9f1941f1 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py @@ -70,27 +70,21 @@ class DatasetOverlapDistributionUpdate(DatasetDistributionCreate): DatasetDistributionUpdate = DatasetOverlapDistributionUpdate -class RecordMetrics(BaseModel): - count: int - - class ResponseMetrics(BaseModel): - count: int + total: int submitted: int discarded: int draft: int + pending: int class DatasetMetrics(BaseModel): - records: RecordMetrics responses: ResponseMetrics class DatasetProgress(BaseModel): total: int - submitted: int - discarded: int - conflicting: int + completed: int pending: int diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 1dbf52fc53..700dfeaefa 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -37,7 +37,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload -from argilla_server.api.schemas.v1.datasets import DatasetProgress from argilla_server.api.schemas.v1.fields import FieldCreate from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate from argilla_server.api.schemas.v1.records import ( @@ -61,7 +60,7 @@ ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema from argilla_server.contexts import accounts, distribution -from argilla_server.enums import DatasetStatus, RecordInclude, UserRole +from argilla_server.enums import DatasetStatus, RecordInclude, UserRole, RecordStatus from 
argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( Dataset, @@ -372,39 +371,85 @@ async def _configure_query_relationships( return query -async def count_records_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int: - return (await db.execute(select(func.count(Record.id)).filter_by(dataset_id=dataset_id))).scalar_one() - - -async def get_dataset_progress(db: AsyncSession, dataset_id: UUID) -> DatasetProgress: - submitted_case = case((Response.status == ResponseStatus.submitted, 1), else_=0) - discarded_case = case((Response.status == ResponseStatus.discarded, 1), else_=0) +async def get_user_dataset_metrics(db: AsyncSession, user_id: UUID, dataset_id: UUID) -> dict: + responses_submitted, responses_discarded, responses_draft, responses_pending = await asyncio.gather( + db.execute( + select(func.count(Response.id)) + .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) + .filter( + Response.user_id == user_id, + Response.status == ResponseStatus.submitted, + ), + ), + db.execute( + select(func.count(Response.id)) + .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) + .filter( + Response.user_id == user_id, + Response.status == ResponseStatus.discarded, + ), + ), + db.execute( + select(func.count(Response.id)) + .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) + .filter( + Response.user_id == user_id, + Response.status == ResponseStatus.draft, + ), + ), + db.execute( + select(func.count(Record.id)) + .outerjoin(Response, and_(Response.record_id == Record.id, Response.user_id == user_id)) + .filter( + Record.dataset_id == dataset_id, + Record.status == RecordStatus.pending, + Response.id == None, + ), + ), + ) - submitted_clause = func.sum(submitted_case) > 0, func.sum(discarded_case) == 0 - discarded_clause = func.sum(discarded_case) > 0, func.sum(submitted_case) == 0 - conflicting_clause = func.sum(submitted_case) > 0, func.sum(discarded_case) > 0 + responses_submitted = responses_submitted.scalar_one() + responses_discarded = responses_discarded.scalar_one() + responses_draft = responses_draft.scalar_one() + responses_pending = responses_pending.scalar_one() + responses_total = responses_submitted + responses_discarded + responses_draft + responses_pending + + return { + "responses": { + "total": responses_total, + "submitted": responses_submitted, + "discarded": responses_discarded, + "draft": responses_draft, + "pending": responses_pending, + }, + } - query = select(Record.id).join(Response).filter(Record.dataset_id == dataset_id).group_by(Record.id) - total, submitted, discarded, conflicting = await asyncio.gather( - count_records_by_dataset_id(db, dataset_id), - db.execute(select(func.count("*")).select_from(query.having(*submitted_clause))), - db.execute(select(func.count("*")).select_from(query.having(*discarded_clause))), - db.execute(select(func.count("*")).select_from(query.having(*conflicting_clause))), +async def get_dataset_progress(db: AsyncSession, dataset_id: UUID) -> dict: + records_completed, records_pending = await asyncio.gather( + db.execute( + select(func.count(Record.id)).filter( + Record.dataset_id == dataset_id, + Record.status == RecordStatus.completed, + ), + ), + db.execute( + select(func.count(Record.id)).filter( + Record.dataset_id == dataset_id, + Record.status == RecordStatus.pending, + ), + ), ) - submitted = submitted.scalar_one() - discarded = discarded.scalar_one() - conflicting = 
conflicting.scalar_one() - pending = total - submitted - discarded - conflicting - - return DatasetProgress( - total=total, - submitted=submitted, - discarded=discarded, - conflicting=conflicting, - pending=pending, - ) + records_completed = records_completed.scalar_one() + records_pending = records_pending.scalar_one() + records_total = records_completed + records_pending + + return { + "total": records_total, + "completed": records_completed, + "pending": records_pending, + } _EXTRA_METADATA_FLAG = "extra" @@ -901,22 +946,6 @@ async def delete_record(db: AsyncSession, search_engine: "SearchEngine", record: return record -async def count_responses_by_dataset_id_and_user_id( - db: AsyncSession, dataset_id: UUID, user_id: UUID, response_status: Optional[ResponseStatus] = None -) -> int: - expressions = [Response.user_id == user_id] - if response_status: - expressions.append(Response.status == response_status) - - return ( - await db.execute( - select(func.count(Response.id)) - .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) - .filter(*expressions) - ) - ).scalar_one() - - async def create_response( db: AsyncSession, search_engine: SearchEngine, record: Record, user: User, response_create: ResponseCreate ) -> Response: diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py index d3cb4e7393..6fb06a06c9 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + from uuid import UUID, uuid4 +from httpx import AsyncClient -import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus, UserRole -from httpx import AsyncClient +from argilla_server.enums import UserRole, RecordStatus -from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, UserFactory +from tests.factories import DatasetFactory, RecordFactory, UserFactory @pytest.mark.asyncio @@ -30,71 +31,16 @@ def url(self, dataset_id: UUID) -> str: async def test_get_dataset_progress(self, async_client: AsyncClient, owner_auth_header: dict): dataset = await DatasetFactory.create() - record_with_one_submitted_response = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_one_submitted_response) - - record_with_multiple_submitted_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create_batch(3, record=record_with_multiple_submitted_responses) - - record_with_one_draft_response = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_one_draft_response, status=ResponseStatus.draft) - - record_with_multiple_draft_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create_batch(3, record=record_with_multiple_draft_responses, status=ResponseStatus.draft) - - record_with_one_discarded_response = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_one_discarded_response, status=ResponseStatus.discarded) - - record_with_multiple_discarded_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create_batch( - 3, record=record_with_multiple_discarded_responses, status=ResponseStatus.discarded - ) - - record_with_mixed_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_mixed_responses) - await ResponseFactory.create(record=record_with_mixed_responses, status=ResponseStatus.draft) - await ResponseFactory.create(record=record_with_mixed_responses, status=ResponseStatus.discarded) - - record_without_responses = await RecordFactory.create(dataset=dataset) - - other_dataset = await DatasetFactory.create() - - other_record_with_one_submitted_response = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_one_submitted_response) - - other_record_with_multiple_submitted_responses = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create_batch(3, record=other_record_with_multiple_submitted_responses) - - other_record_with_one_draft_response = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_one_draft_response, status=ResponseStatus.draft) - - other_record_with_multiple_draft_responses = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create_batch( - 3, record=other_record_with_multiple_draft_responses, status=ResponseStatus.draft - ) - - other_record_with_one_discarded_response = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_one_discarded_response, status=ResponseStatus.discarded) - - other_record_with_multiple_discarded_responses = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create_batch( - 3, record=other_record_with_multiple_discarded_responses, status=ResponseStatus.discarded - ) - - other_record_with_mixed_responses = 
await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_mixed_responses) - await ResponseFactory.create(record=other_record_with_mixed_responses, status=ResponseStatus.draft) - await ResponseFactory.create(record=other_record_with_mixed_responses, status=ResponseStatus.discarded) + records_completed = await RecordFactory.create_batch(3, status=RecordStatus.completed, dataset=dataset) + records_pending = await RecordFactory.create_batch(2, status=RecordStatus.pending, dataset=dataset) response = await async_client.get(self.url(dataset.id), headers=owner_auth_header) assert response.status_code == 200 assert response.json() == { - "total": 8, - "submitted": 2, - "discarded": 2, - "conflicting": 1, - "pending": 3, + "completed": 3, + "pending": 2, + "total": 5, } async def test_get_dataset_progress_with_empty_dataset(self, async_client: AsyncClient, owner_auth_header: dict): @@ -104,11 +50,9 @@ async def test_get_dataset_progress_with_empty_dataset(self, async_client: Async assert response.status_code == 200 assert response.json() == { - "total": 0, - "submitted": 0, - "discarded": 0, - "conflicting": 0, + "completed": 0, "pending": 0, + "total": 0, } @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py index cdb9b06ea2..097bc0a1ec 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -15,9 +15,10 @@ from uuid import UUID import pytest -from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus from httpx import AsyncClient +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus + from tests.factories import DatasetFactory diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index e0c9fe4d5e..9404b3850e 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -735,11 +735,12 @@ async def test_get_current_user_dataset_metrics( self, async_client: "AsyncClient", owner: User, owner_auth_header: dict ): dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) + record_a = await RecordFactory.create(dataset=dataset, status=RecordStatus.completed) + record_b = await RecordFactory.create(dataset=dataset, status=RecordStatus.completed) record_c = await RecordFactory.create(dataset=dataset) record_d = await RecordFactory.create(dataset=dataset) await RecordFactory.create_batch(3, dataset=dataset) + await RecordFactory.create_batch(2, dataset=dataset, status=RecordStatus.completed) await ResponseFactory.create(record=record_a, user=owner) await ResponseFactory.create(record=record_b, user=owner, status=ResponseStatus.discarded) await ResponseFactory.create(record=record_c, user=owner, status=ResponseStatus.discarded) @@ -758,33 +759,43 @@ async def test_get_current_user_dataset_metrics( assert response.status_code == 200 assert response.json() == { - "records": { - "count": 7, - }, "responses": { - "count": 4, + "total": 7, "submitted": 1, "discarded": 2, "draft": 1, + "pending": 3, }, } - async def test_get_current_user_dataset_metrics_without_authentication(self, 
async_client: "AsyncClient"): + async def test_get_current_user_dataset_metrics_with_empty_dataset( + self, async_client: "AsyncClient", owner_auth_header: dict + ): dataset = await DatasetFactory.create() - response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/metrics") + response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/metrics", headers=owner_auth_header) - assert response.status_code == 401 + assert response.status_code == 200 + assert response.json() == { + "responses": { + "total": 0, + "submitted": 0, + "discarded": 0, + "draft": 0, + "pending": 0, + }, + } @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) async def test_get_current_user_dataset_metrics_as_annotator(self, async_client: "AsyncClient", role: UserRole): dataset = await DatasetFactory.create() user = await AnnotatorFactory.create(workspaces=[dataset.workspace], role=role) record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) + record_b = await RecordFactory.create(dataset=dataset, status=RecordStatus.completed) record_c = await RecordFactory.create(dataset=dataset) record_d = await RecordFactory.create(dataset=dataset) await RecordFactory.create_batch(2, dataset=dataset) + await RecordFactory.create_batch(3, dataset=dataset, status=RecordStatus.completed) await ResponseFactory.create(record=record_a, user=user) await ResponseFactory.create(record=record_b, user=user) await ResponseFactory.create(record=record_c, user=user, status=ResponseStatus.discarded) @@ -800,15 +811,28 @@ async def test_get_current_user_dataset_metrics_as_annotator(self, async_client: await ResponseFactory.create(record=other_record_c, status=ResponseStatus.discarded) response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/metrics", headers={API_KEY_HEADER_NAME: user.api_key} + f"/api/v1/me/datasets/{dataset.id}/metrics", + headers={API_KEY_HEADER_NAME: user.api_key}, ) assert response.status_code == 200 assert response.json() == { - "records": {"count": 6}, - "responses": {"count": 4, "submitted": 2, "discarded": 1, "draft": 1}, + "responses": { + "total": 6, + "submitted": 2, + "discarded": 1, + "draft": 1, + "pending": 2, + }, } + async def test_get_current_user_dataset_metrics_without_authentication(self, async_client: "AsyncClient"): + dataset = await DatasetFactory.create() + + response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/metrics") + + assert response.status_code == 401 + @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) async def test_get_current_user_dataset_metrics_restricted_user_from_different_workspace( self, async_client: "AsyncClient", role: UserRole From 808c837ce812de045691620c4c97aa387458d937 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 8 Jul 2024 17:09:40 +0200 Subject: [PATCH 05/34] [ENHANCEMENT]: `argilla-server`: allow update distribution for non annotated datasets (#5171) # Description This PR changes the current dataset update validator to allow updating the distribution task settings as long as none of the dataset records has a response yet.
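For reference, a minimal illustrative sketch of the intended behavior from the client side (the endpoint and payload mirror the updated tests below; the base URL, API key value, and header name are assumptions, not part of this PR):

```python
# Sketch only: assumes a locally running Argilla server; header name, API key
# value, and dataset id are placeholders/assumptions for illustration.
import httpx

dataset_id = "00000000-0000-0000-0000-000000000000"  # placeholder dataset id

client = httpx.Client(
    base_url="http://localhost:6900",
    headers={"X-Argilla-Api-Key": "owner.apikey"},
)

# Published dataset whose records have no responses yet: the update is now accepted.
response = client.patch(
    f"/api/v1/datasets/{dataset_id}",
    json={"distribution": {"strategy": "overlap", "min_submitted": 4}},
)
assert response.status_code == 200

# Once any record in the dataset has a response, the same request is rejected
# with a 422 and the error message defined in the validator below.
```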
cc @nataliaElv **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../src/argilla_server/models/database.py | 11 ++++-- .../src/argilla_server/validators/datasets.py | 10 +++--- .../v1/datasets/test_update_dataset.py | 36 +++++++++++++++---- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 3230916362..6b9580dbb5 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -17,12 +17,12 @@ from typing import Any, List, Optional, Union from uuid import UUID -from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql +from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text from sqlalchemy import Enum as SAEnum from sqlalchemy.engine.default import DefaultExecutionContext from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.mutable import MutableDict, MutableList -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property from argilla_server.api.schemas.v1.questions import QuestionSettings from argilla_server.enums import ( @@ -361,6 +361,13 @@ class Dataset(DatabaseModel): __table_args__ = (UniqueConstraint("name", "workspace_id", name="dataset_name_workspace_id_uq"),) + @property + async def responses_count(self) -> int: + # TODO: This should be moved to proper repository + return await async_object_session(self).scalar( + select(func.count(Response.id)).join(Record).where(Record.dataset_id == self.id) + ) + @property def is_draft(self): return self.status == DatasetStatus.draft diff --git a/argilla-server/src/argilla_server/validators/datasets.py b/argilla-server/src/argilla_server/validators/datasets.py index aae2a5fc83..eb52576d41 100644 --- a/argilla-server/src/argilla_server/validators/datasets.py +++ b/argilla-server/src/argilla_server/validators/datasets.py @@ -40,9 +40,11 @@ async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, wor class DatasetUpdateValidator: @classmethod async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None: - cls._validate_distribution(dataset, dataset_attrs) + await cls._validate_distribution(dataset, dataset_attrs) @classmethod - def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: - if dataset.is_ready and dataset_attrs.get("distribution") is not None: - raise UnprocessableEntityError(f"Distribution settings cannot be modified for a published dataset") + async def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: + if dataset_attrs.get("distribution") is not None and (await dataset.responses_count) > 0: + raise UnprocessableEntityError( + "Distribution settings cannot be modified for a dataset with records including responses" + ) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py 
b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py index 097bc0a1ec..ea732d0536 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -18,8 +18,7 @@ from httpx import AsyncClient from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus - -from tests.factories import DatasetFactory +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory @pytest.mark.asyncio @@ -96,7 +95,7 @@ async def test_update_dataset_without_distribution_for_published_dataset( "min_submitted": 1, } - async def test_update_dataset_distribution_for_published_dataset( + async def test_update_dataset_distribution_for_published_dataset_without_responses( self, async_client: AsyncClient, owner_auth_header: dict ): dataset = await DatasetFactory.create(status=DatasetStatus.ready) @@ -112,12 +111,37 @@ async def test_update_dataset_distribution_for_published_dataset( }, ) - assert response.status_code == 422 - assert response.json() == {"detail": "Distribution settings cannot be modified for a published dataset"} + assert response.status_code == 200 assert dataset.distribution == { "strategy": DatasetDistributionStrategy.overlap, - "min_submitted": 1, + "min_submitted": 4, + } + + async def test_update_dataset_distribution_for_dataset_with_responses( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + records = await RecordFactory.create_batch(10, dataset=dataset) + + for record in records: + await ResponseFactory.create(record=record) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 422 + + assert response.json() == { + "detail": "Distribution settings cannot be modified for a dataset with records including responses" } async def test_update_dataset_distribution_with_invalid_strategy( From 3d74a3391a481f55011c5d4dcce1f82bc1b0395c Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 11:07:04 +0200 Subject: [PATCH 06/34] chore: Add status field to record model --- argilla/src/argilla/_models/_record/_record.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/argilla/src/argilla/_models/_record/_record.py b/argilla/src/argilla/_models/_record/_record.py index 38a4996c96..05f7d0971c 100644 --- a/argilla/src/argilla/_models/_record/_record.py +++ b/argilla/src/argilla/_models/_record/_record.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Literal from pydantic import Field, field_serializer, field_validator @@ -36,6 +36,8 @@ class RecordModel(ResourceModel): responses: Optional[List[UserResponseModel]] = Field(default_factory=list) suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple) + status: Literal["pending", "completed"] = "pending" + external_id: Optional[Any] = None @field_serializer("external_id", when_used="unless-none") From 7b7d2f5d81cf6cda464acf1a5fb1362d56dc2425 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 11:11:57 +0200 Subject: [PATCH 07/34] feat: Add read-only property 'status' to the record resource --- argilla/src/argilla/records/_resource.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index 53c1321b4b..0aa38b86be 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -103,7 +103,7 @@ def __init__( def __repr__(self) -> str: return ( - f"Record(id={self.id},fields={self.fields},metadata={self.metadata}," + f"Record(id={self.id},status={self.status},fields={self.fields},metadata={self.metadata}," f"suggestions={self.suggestions},responses={self.responses})" ) @@ -147,6 +147,10 @@ def metadata(self) -> "RecordMetadata": def vectors(self) -> "RecordVectors": return self.__vectors + @property + def status(self) -> "str": + return self._model.status + @property def _server_id(self) -> Optional[UUID]: return self._model.id @@ -164,6 +168,7 @@ def api_model(self) -> RecordModel: vectors=self.vectors.api_models(), responses=self.responses.api_models(), suggestions=self.suggestions.api_models(), + status=self.status, ) def serialize(self) -> Dict[str, Any]: @@ -185,6 +190,7 @@ def to_dict(self) -> Dict[str, Dict]: """ id = str(self.id) if self.id else None server_id = str(self._model.id) if self._model.id else None + status = self.status fields = self.fields.to_dict() metadata = self.metadata.to_dict() suggestions = self.suggestions.to_dict() @@ -198,6 +204,7 @@ def to_dict(self) -> Dict[str, Dict]: "suggestions": suggestions, "responses": responses, "vectors": vectors, + "status": status, "_server_id": server_id, } @@ -245,7 +252,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": Returns: A Record object. """ - return cls( + instance = cls( id=model.external_id, fields=model.fields, metadata={meta.name: meta.value for meta in model.metadata}, @@ -257,10 +264,15 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": for response in UserResponse.from_model(response_model, dataset=dataset) ], suggestions=[Suggestion.from_model(model=suggestion, dataset=dataset) for suggestion in model.suggestions], - _dataset=dataset, - _server_id=model.id, ) + # set private attributes + instance._dataset = dataset + instance._model.id = model.id + instance._model.status = model.status + + return instance + class RecordFields(dict): """This is a container class for the fields of a Record. 
@@ -335,7 +347,7 @@ def to_dict(self) -> Dict[str, List[Dict]]: response_dict = defaultdict(list) for response in self.__responses: response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)}) - return response_dict + return dict(response_dict) def api_models(self) -> List[UserResponseModel]: """Returns a list of ResponseModel objects.""" From 736bfc9a7d823145bf212a3290057be7111eb90d Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 11:13:21 +0200 Subject: [PATCH 08/34] tests: Update tests to reflect the status property --- argilla/tests/integration/test_list_records.py | 13 +++++++++++++ argilla/tests/unit/test_io/test_generic.py | 1 + argilla/tests/unit/test_io/test_hf_datasets.py | 1 + argilla/tests/unit/test_resources/test_records.py | 10 ++++++++++ 4 files changed, 25 insertions(+) diff --git a/argilla/tests/integration/test_list_records.py b/argilla/tests/integration/test_list_records.py index ec51124be5..58407bc273 100644 --- a/argilla/tests/integration/test_list_records.py +++ b/argilla/tests/integration/test_list_records.py @@ -65,6 +65,19 @@ def test_list_records_with_start_offset(client: Argilla, dataset: Dataset): records = list(dataset.records(start_offset=1)) assert len(records) == 1 + assert [record.to_dict() for record in records] == [ + { + "_server_id": str(records[0]._server_id), + "fields": {"text": "The record text field"}, + "id": "2", + "status": "pending", + "metadata": {}, + "responses": {}, + "suggestions": {}, + "vectors": {}, + } + ] + def test_list_records_with_responses(client: Argilla, dataset: Dataset): dataset.records.log( diff --git a/argilla/tests/unit/test_io/test_generic.py b/argilla/tests/unit/test_io/test_generic.py index 446693f5b5..374ee20eed 100644 --- a/argilla/tests/unit/test_io/test_generic.py +++ b/argilla/tests/unit/test_io/test_generic.py @@ -41,6 +41,7 @@ def test_to_list_flatten(self): assert records_list == [ { "id": str(record.id), + "status": "pending", "_server_id": None, "field": "The field", "key": "value", diff --git a/argilla/tests/unit/test_io/test_hf_datasets.py b/argilla/tests/unit/test_io/test_hf_datasets.py index f13ab04ef4..99e43d8caf 100644 --- a/argilla/tests/unit/test_io/test_hf_datasets.py +++ b/argilla/tests/unit/test_io/test_hf_datasets.py @@ -46,6 +46,7 @@ def test_to_datasets_with_partial_values_in_records(self): ds = HFDatasetsIO.to_datasets(records) assert ds.features == { + "status": Value(dtype="string", id=None), "_server_id": Value(dtype="null", id=None), "a": Value(dtype="string", id=None), "b": Value(dtype="string", id=None), diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index 09759430c7..96aa968b03 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -14,6 +14,8 @@ import uuid +import pytest + from argilla import Record, Suggestion, Response from argilla._models import MetadataModel @@ -31,6 +33,7 @@ def test_record_repr(self): ) assert ( record.__repr__() == f"Record(id={record_id}," + "status=pending," "fields={'name': 'John', 'age': '30'}," "metadata={'key': 'value'}," "suggestions={'question': {'value': 'answer', 'score': None, 'agent': None}}," @@ -62,3 +65,10 @@ def test_update_record_vectors(self): record.vectors["new-vector"] = [1.0, 2.0, 3.0] assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]} + + def test_prevent_update_record(self): + record = Record(fields={"name": 
"John"}) + assert record.status == "pending" + + with pytest.raises(AttributeError, match="can't set attribute 'status'"): + record.status = "completed" From f241e41acde8110046d8f4667863896ce7d0543e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 11:18:14 +0200 Subject: [PATCH 09/34] fix: wrong filter naming after merge from develop --- argilla/src/argilla/records/_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index dfa2f4c99c..07a972bde9 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -56,7 +56,7 @@ def _extract_filter_scope(field: str) -> ScopeModel: field = field.strip() if field == "status": return RecordFilterScopeModel(property="status") - elif field == "responses.status": + elif field == "response.status": return ResponseFilterScopeModel(property="status") elif "metadata" in field: _, md_property = field.split(".") From 9b84dcf3b1ba0aba3156c64c29b55e3f33f0c90e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 13:19:07 +0200 Subject: [PATCH 10/34] chore: Remove message match (depends on python version --- argilla/tests/unit/test_resources/test_records.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index 96aa968b03..b04adb0203 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -70,5 +70,5 @@ def test_prevent_update_record(self): record = Record(fields={"name": "John"}) assert record.status == "pending" - with pytest.raises(AttributeError, match="can't set attribute 'status'"): + with pytest.raises(AttributeError): record.status = "completed" From 08e5757df9dd03aa103dc71000d6efe5815493fc Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 14:23:03 +0200 Subject: [PATCH 11/34] chore: Add task distribution model --- argilla/src/argilla/_api/_datasets.py | 8 ++--- argilla/src/argilla/_models/_dataset.py | 9 ++++-- .../_models/_settings/_task_distribution.py | 29 +++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) create mode 100644 argilla/src/argilla/_models/_settings/_task_distribution.py diff --git a/argilla/src/argilla/_api/_datasets.py b/argilla/src/argilla/_api/_datasets.py index 70504f650c..5575fa052f 100644 --- a/argilla/src/argilla/_api/_datasets.py +++ b/argilla/src/argilla/_api/_datasets.py @@ -35,7 +35,7 @@ class DatasetsAPI(ResourceAPI[DatasetModel]): @api_error_handler def create(self, dataset: "DatasetModel") -> "DatasetModel": - json_body = dataset.model_dump() + json_body = dataset.model_dump(exclude_unset=True) response = self.http_client.post( url=self.url_stub, json=json_body, @@ -48,13 +48,13 @@ def create(self, dataset: "DatasetModel") -> "DatasetModel": @api_error_handler def update(self, dataset: "DatasetModel") -> "DatasetModel": - json_body = dataset.model_dump() + json_body = dataset.model_dump(exclude_unset=True) dataset_id = json_body["id"] # type: ignore response = self.http_client.patch(f"{self.url_stub}/{dataset_id}", json=json_body) response.raise_for_status() response_json = response.json() dataset = self._model_from_json(response_json=response_json) - self._log_message(message=f"Updated dataset {dataset.url}") + self._log_message(message=f"Updated dataset {dataset.id}") return dataset @api_error_handler @@ -63,7 +63,7 @@ def 
get(self, dataset_id: UUID) -> "DatasetModel": response.raise_for_status() response_json = response.json() dataset = self._model_from_json(response_json=response_json) - self._log_message(message=f"Got dataset {dataset.url}") + self._log_message(message=f"Got dataset {dataset.id}") return dataset @api_error_handler diff --git a/argilla/src/argilla/_models/_dataset.py b/argilla/src/argilla/_models/_dataset.py index cd851752a5..098deff2cb 100644 --- a/argilla/src/argilla/_models/_dataset.py +++ b/argilla/src/argilla/_models/_dataset.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from datetime import datetime -from uuid import UUID from typing import Literal +from typing import Optional +from uuid import UUID from pydantic import field_serializer, ConfigDict @@ -23,6 +23,8 @@ __all__ = ["DatasetModel"] +from argilla._models._settings._task_distribution import TaskDistributionModel + class DatasetModel(ResourceModel): name: str @@ -31,9 +33,10 @@ class DatasetModel(ResourceModel): guidelines: Optional[str] = None allow_extra_metadata: bool = True # Ideally, the default value should be provided by the server + distribution: Optional[TaskDistributionModel] = None + workspace_id: Optional[UUID] = None last_activity_at: Optional[datetime] = None - url: Optional[str] = None model_config = ConfigDict( validate_assignment=True, diff --git a/argilla/src/argilla/_models/_settings/_task_distribution.py b/argilla/src/argilla/_models/_settings/_task_distribution.py new file mode 100644 index 0000000000..d44f1851c3 --- /dev/null +++ b/argilla/src/argilla/_models/_settings/_task_distribution.py @@ -0,0 +1,29 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = ["TaskDistributionModel", "OverlapTaskDistributionModel"] + +from typing import Literal + +from pydantic import BaseModel, PositiveInt, ConfigDict + + +class OverlapTaskDistributionModel(BaseModel): + strategy: Literal["overlap"] + min_submitted: PositiveInt + + model_config = ConfigDict(validate_assignment=True) + + +TaskDistributionModel = OverlapTaskDistributionModel From 443b9d050862051e00cd1bb2053955049deee6ae Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 14:25:28 +0200 Subject: [PATCH 12/34] feat: Add support to task distribution --- argilla/src/argilla/settings/_resource.py | 34 +++++++-- .../argilla/settings/_task_distribution.py | 70 +++++++++++++++++++ 2 files changed, 99 insertions(+), 5 deletions(-) create mode 100644 argilla/src/argilla/settings/_task_distribution.py diff --git a/argilla/src/argilla/settings/_resource.py b/argilla/src/argilla/settings/_resource.py index 7a6dc4ea86..b4bd97524f 100644 --- a/argilla/src/argilla/settings/_resource.py +++ b/argilla/src/argilla/settings/_resource.py @@ -25,12 +25,12 @@ from argilla.settings._field import TextField from argilla.settings._metadata import MetadataType, MetadataField from argilla.settings._question import QuestionType, question_from_model, question_from_dict, QuestionPropertyBase +from argilla.settings._task_distribution import OverlapTaskDistribution, DEFAULT_TASK_DISTRIBUTION, TaskDistribution from argilla.settings._vector import VectorField if TYPE_CHECKING: from argilla.datasets import Dataset - __all__ = ["Settings"] @@ -49,16 +49,21 @@ def __init__( metadata: Optional[List[MetadataType]] = None, guidelines: Optional[str] = None, allow_extra_metadata: bool = False, + distribution: Optional[TaskDistribution] = None, _dataset: Optional["Dataset"] = None, ) -> None: """ Args: fields (List[TextField]): A list of TextField objects that represent the fields in the Dataset. - questions (List[Union[LabelQuestion, MultiLabelQuestion, RankingQuestion, TextQuestion, RatingQuestion]]): A list of Question objects that represent the questions in the Dataset. + questions (List[Union[LabelQuestion, MultiLabelQuestion, RankingQuestion, TextQuestion, RatingQuestion]]): + A list of Question objects that represent the questions in the Dataset. vectors (List[VectorField]): A list of VectorField objects that represent the vectors in the Dataset. metadata (List[MetadataField]): A list of MetadataField objects that represent the metadata in the Dataset. guidelines (str): A string containing the guidelines for the Dataset. - allow_extra_metadata (bool): A boolean that determines whether or not extra metadata is allowed in the Dataset. Defaults to False. + allow_extra_metadata (bool): A boolean that determines whether or not extra metadata is allowed in the + Dataset. Defaults to False. + distribution (TaskDistribution): The annotation task distribution configuration. 
+ Default to DEFAULT_TASK_DISTRIBUTION """ super().__init__(client=_dataset._client if _dataset else None) @@ -70,6 +75,8 @@ def __init__( self.__guidelines = self.__process_guidelines(guidelines) self.__allow_extra_metadata = allow_extra_metadata + self._distribution = distribution or DEFAULT_TASK_DISTRIBUTION + self._dataset = _dataset ##################### @@ -124,6 +131,14 @@ def allow_extra_metadata(self) -> bool: def allow_extra_metadata(self, value: bool): self.__allow_extra_metadata = value + @property + def distribution(self) -> TaskDistribution: + return self._distribution + + @distribution.setter + def distribution(self, value: TaskDistribution) -> None: + self._distribution = value + @property def dataset(self) -> "Dataset": return self._dataset @@ -168,7 +183,7 @@ def get(self) -> "Settings": self.questions = self._fetch_questions() self.vectors = self._fetch_vectors() self.metadata = self._fetch_metadata() - self.__get_dataset_related_attributes() + self.__fetch_dataset_related_attributes() self._update_last_api_call() return self @@ -218,6 +233,7 @@ def serialize(self): "vectors": self.vectors.serialize(), "metadata": self.metadata.serialize(), "allow_extra_metadata": self.allow_extra_metadata, + "distribution": self.distribution.to_dict(), } except Exception as e: raise ArgillaSerializeError(f"Failed to serialize the settings. {e.__class__.__name__}") from e @@ -246,6 +262,7 @@ def from_json(cls, path: Union[Path, str]) -> "Settings": vectors = settings_dict.get("vectors", []) metadata = settings_dict.get("metadata", []) guidelines = settings_dict.get("guidelines") + distribution = settings_dict.get("distribution") allow_extra_metadata = settings_dict.get("allow_extra_metadata") questions = [question_from_dict(question) for question in settings_dict.get("questions", [])] @@ -253,6 +270,9 @@ def from_json(cls, path: Union[Path, str]) -> "Settings": vectors = [VectorField.from_dict(vector) for vector in vectors] metadata = [MetadataField.from_dict(metadata) for metadata in metadata] + if distribution: + distribution = OverlapTaskDistribution.from_dict(distribution) + return cls( questions=questions, fields=fields, @@ -260,6 +280,7 @@ def from_json(cls, path: Union[Path, str]) -> "Settings": metadata=metadata, guidelines=guidelines, allow_extra_metadata=allow_extra_metadata, + distribution=distribution, ) def __eq__(self, other: "Settings") -> bool: @@ -272,6 +293,7 @@ def __eq__(self, other: "Settings") -> bool: def __repr__(self) -> str: return ( f"Settings(guidelines={self.guidelines}, allow_extra_metadata={self.allow_extra_metadata}, " + f"distribution={self.distribution}, " f"fields={self.fields}, questions={self.questions}, vectors={self.vectors}, metadata={self.metadata})" ) @@ -295,7 +317,7 @@ def _fetch_metadata(self) -> List[MetadataType]: models = self._client.api.metadata.list(dataset_id=self._dataset.id) return [MetadataField.from_model(model) for model in models] - def __get_dataset_related_attributes(self): + def __fetch_dataset_related_attributes(self): # This flow may be a bit weird, but it's the only way to update the dataset related attributes # Everything is point that we should have several settings-related endpoints in the API to handle this. 
# POST /api/v1/datasets/{dataset_id}/settings @@ -308,6 +330,7 @@ def __get_dataset_related_attributes(self): self.guidelines = dataset_model.guidelines self.allow_extra_metadata = dataset_model.allow_extra_metadata + self.distribution = TaskDistribution.from_model(dataset_model.distribution) def _update_dataset_related_attributes(self): # This flow may be a bit weird, but it's the only way to update the dataset related attributes @@ -323,6 +346,7 @@ def _update_dataset_related_attributes(self): name=self._dataset.name, guidelines=self.guidelines, allow_extra_metadata=self.allow_extra_metadata, + distribution=self.distribution._api_model(), ) self._client.api.datasets.update(dataset_model) diff --git a/argilla/src/argilla/settings/_task_distribution.py b/argilla/src/argilla/settings/_task_distribution.py new file mode 100644 index 0000000000..593df1c681 --- /dev/null +++ b/argilla/src/argilla/settings/_task_distribution.py @@ -0,0 +1,70 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from typing import Literal, Any, Dict +from typing import Dict, Any, Literal + +from argilla._models._settings._task_distribution import OverlapTaskDistributionModel + + +class OverlapTaskDistribution: + """The task distribution settings class. + + This task distribution defines a number of submitted record required to complete a record. + We could support multiple task distribution strategies in the future + + Args: + min_submitted (int): The number of min. 
submitted responses to complete the record + """ + + strategy: Literal["overlap"] = "overlap" + + def __init__(self, min_submitted: int): + self._model = OverlapTaskDistributionModel(min_submitted=min_submitted, strategy=self.strategy) + + def __repr__(self) -> str: + return f"OverlapTaskDistribution(min_submitted={self.min_submitted})" + + def __eq__(self, other) -> bool: + if not isinstance(other, self.__class__): + return False + + return self._model == other._model + + @property + def min_submitted(self): + return self._model.min_submitted + + @min_submitted.setter + def min_submitted(self, value: int): + self._model.min_submitted = value + + @classmethod + def from_model(cls, model: OverlapTaskDistributionModel) -> "OverlapTaskDistribution": + return cls(min_submitted=model.min_submitted) + + @classmethod + def from_dict(cls, dict: Dict[str, Any]) -> "OverlapTaskDistribution": + return cls.from_model(OverlapTaskDistributionModel.model_validate(dict)) + + def to_dict(self): + return self._model.model_dump() + + def _api_model(self) -> OverlapTaskDistributionModel: + return self._model + + +TaskDistribution = OverlapTaskDistribution + +DEFAULT_TASK_DISTRIBUTION = OverlapTaskDistribution(min_submitted=1) From 303361a7909ee10b5555108414d987c1975c7577 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 14:27:13 +0200 Subject: [PATCH 13/34] tests: Update tests with task distribution --- argilla/tests/integration/conftest.py | 7 +++++ .../tests/integration/test_create_datasets.py | 28 +++++++++++++------ .../test_update_dataset_settings.py | 15 +++++++--- .../tests/unit/test_settings/test_settings.py | 5 ++-- 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/argilla/tests/integration/conftest.py b/argilla/tests/integration/conftest.py index 2ffd41290b..4c545c961e 100644 --- a/argilla/tests/integration/conftest.py +++ b/argilla/tests/integration/conftest.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import uuid import pytest @@ -34,3 +35,9 @@ def cleanup(client: rg.Argilla): for user in client.users: if user.username.startswith("test_"): user.delete() + + +@pytest.fixture() +def dataset_name() -> str: + """use this fixture to autogenerate a safe dataset name for tests""" + return f"test_dataset_{uuid.uuid4()}" diff --git a/argilla/tests/integration/test_create_datasets.py b/argilla/tests/integration/test_create_datasets.py index 5ea3850c41..0a3fc30e95 100644 --- a/argilla/tests/integration/test_create_datasets.py +++ b/argilla/tests/integration/test_create_datasets.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import uuid import pytest -from argilla import Argilla, Dataset, Settings, TextField, RatingQuestion +from argilla import Argilla, Dataset, Settings, TextField, RatingQuestion, LabelQuestion +from argilla.settings._task_distribution import TaskDistribution @pytest.fixture(scope="session", autouse=True) @@ -28,8 +28,7 @@ def clean_datasets(client: Argilla): class TestCreateDatasets: - def test_create_dataset(self, client: Argilla): - dataset_name = f"test_dataset_{uuid.uuid4()}" + def test_create_dataset(self, client: Argilla, dataset_name: str): dataset = Dataset( name=dataset_name, settings=Settings( @@ -44,10 +43,9 @@ def test_create_dataset(self, client: Argilla): created_dataset = client.datasets(name=dataset_name) assert created_dataset.settings == dataset.settings + assert created_dataset.settings.distribution == TaskDistribution(min_submitted=1) - def test_create_multiple_dataset_with_same_settings(self, client: Argilla): - dataset_name = f"test_dataset_{uuid.uuid4()}" - + def test_create_multiple_dataset_with_same_settings(self, client: Argilla, dataset_name: str): settings = Settings( fields=[TextField(name="text")], questions=[RatingQuestion(name="question", values=[1, 2, 3, 4, 5])], @@ -70,8 +68,7 @@ def test_create_multiple_dataset_with_same_settings(self, client: Argilla): assert schema["question"].name == "question" assert schema["question"].values == [1, 2, 3, 4, 5] - def test_create_dataset_from_existing_dataset(self, client: Argilla): - dataset_name = f"test_dataset_{uuid.uuid4()}" + def test_create_dataset_from_existing_dataset(self, client: Argilla, dataset_name: str): dataset = Dataset( name=dataset_name, settings=Settings( @@ -92,3 +89,16 @@ def test_create_dataset_from_existing_dataset(self, client: Argilla): assert isinstance(schema["question"], RatingQuestion) assert schema["question"].name == "question" assert schema["question"].values == [1, 2, 3, 4, 5] + + def test_create_dataset_with_custom_task_distribution(self, client: Argilla, dataset_name: str): + task_distribution = TaskDistribution(min_submitted=4) + + settings = Settings( + fields=[TextField(name="text", title="text")], + questions=[LabelQuestion(name="label", title="text", labels=["positive", "negative"])], + distribution=task_distribution, + ) + dataset = Dataset(dataset_name, settings=settings).create() + + assert dataset.exists() + assert dataset.settings.distribution == task_distribution diff --git a/argilla/tests/integration/test_update_dataset_settings.py b/argilla/tests/integration/test_update_dataset_settings.py index 0a606481e5..06752472d3 100644 --- a/argilla/tests/integration/test_update_dataset_settings.py +++ b/argilla/tests/integration/test_update_dataset_settings.py @@ -20,9 +20,9 @@ @pytest.fixture -def dataset(): +def dataset(dataset_name: str): return Dataset( - name=f"test_dataset_{uuid.uuid4().int}", + name=dataset_name, settings=Settings( fields=[TextField(name="text", use_markdown=False)], questions=[LabelQuestion(name="label", labels=["a", "b", "c"])], @@ -34,7 +34,7 @@ class TestUpdateDatasetSettings: def test_update_settings(self, client: Argilla, dataset: Dataset): settings = dataset.settings - settings.fields.text.use_markdown = True + settings.fields["text"].use_markdown = True dataset.settings.vectors.add(VectorField(name="vector", dimensions=10)) dataset.settings.metadata.add(FloatMetadataProperty(name="metadata")) dataset.settings.update() @@ -43,10 +43,17 @@ def test_update_settings(self, client: Argilla, dataset: Dataset): settings = dataset.settings assert 
settings.fields["text"].use_markdown is True assert settings.vectors["vector"].dimensions == 10 - assert isinstance(settings.metadata.metadata, FloatMetadataProperty) + assert isinstance(settings.metadata["metadata"], FloatMetadataProperty) settings.vectors["vector"].title = "A new title for vector" settings.update() dataset = client.datasets(dataset.name) assert dataset.settings.vectors["vector"].title == "A new title for vector" + + def test_update_distribution_settings(self, client: Argilla, dataset: Dataset): + dataset.settings.distribution.min_submitted = 100 + dataset.update() + + dataset = client.datasets(dataset.name) + assert dataset.settings.distribution.min_submitted == 100 diff --git a/argilla/tests/unit/test_settings/test_settings.py b/argilla/tests/unit/test_settings/test_settings.py index 3d7e6d9a37..56b8cb1a58 100644 --- a/argilla/tests/unit/test_settings/test_settings.py +++ b/argilla/tests/unit/test_settings/test_settings.py @@ -77,8 +77,9 @@ def test_settings_repr(self): ) assert ( - settings.__repr__() - == f"""Settings(guidelines=None, allow_extra_metadata=False, fields={settings.fields}, questions={settings.questions}, vectors={settings.vectors}, metadata={settings.metadata})""" + settings.__repr__() == f"Settings(guidelines=None, allow_extra_metadata=False, " + "distribution=OverlapTaskDistribution(min_submitted=1), " + f"fields={settings.fields}, questions={settings.questions}, vectors={settings.vectors}, metadata={settings.metadata})" ) def test_settings_validation_with_duplicated_names(self): From 43ba10f03e90b0eb6d315e38c83e8d6458e5ea5c Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 14:38:36 +0200 Subject: [PATCH 14/34] chore: Use main TaskDistribution naming --- argilla/src/argilla/settings/_resource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/argilla/src/argilla/settings/_resource.py b/argilla/src/argilla/settings/_resource.py index b4bd97524f..8b9e811d62 100644 --- a/argilla/src/argilla/settings/_resource.py +++ b/argilla/src/argilla/settings/_resource.py @@ -25,7 +25,7 @@ from argilla.settings._field import TextField from argilla.settings._metadata import MetadataType, MetadataField from argilla.settings._question import QuestionType, question_from_model, question_from_dict, QuestionPropertyBase -from argilla.settings._task_distribution import OverlapTaskDistribution, DEFAULT_TASK_DISTRIBUTION, TaskDistribution +from argilla.settings._task_distribution import DEFAULT_TASK_DISTRIBUTION, TaskDistribution from argilla.settings._vector import VectorField if TYPE_CHECKING: @@ -271,7 +271,7 @@ def from_json(cls, path: Union[Path, str]) -> "Settings": metadata = [MetadataField.from_dict(metadata) for metadata in metadata] if distribution: - distribution = OverlapTaskDistribution.from_dict(distribution) + distribution = TaskDistribution.from_dict(distribution) return cls( questions=questions, From d6c186bdb9f0be40e548481fec4c0bf32de9f1cb Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 14:55:08 +0200 Subject: [PATCH 15/34] ci: Using feat branch docker image --- .github/workflows/argilla.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/argilla.yml b/.github/workflows/argilla.yml index fea7e140bc..94517a0fd8 100644 --- a/.github/workflows/argilla.yml +++ b/.github/workflows/argilla.yml @@ -21,7 +21,7 @@ jobs: build: services: argilla-quickstart: - image: argilla/argilla-quickstart:main + image: argilladev/argilla-quickstart:pr-5136 ports: - 6900:6900 env:
From aba06c7e45460653b9214b68914dc9e242493106 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 10 Jul 2024 11:37:47 +0200 Subject: [PATCH 16/34] Update argilla/src/argilla/_models/_dataset.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Francisco Calvo --- argilla/src/argilla/_models/_dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/argilla/src/argilla/_models/_dataset.py b/argilla/src/argilla/_models/_dataset.py index 098deff2cb..32e2a5d2f8 100644 --- a/argilla/src/argilla/_models/_dataset.py +++ b/argilla/src/argilla/_models/_dataset.py @@ -32,9 +32,7 @@ class DatasetModel(ResourceModel): guidelines: Optional[str] = None allow_extra_metadata: bool = True # Ideally, the default value should be provided by the server - distribution: Optional[TaskDistributionModel] = None - workspace_id: Optional[UUID] = None last_activity_at: Optional[datetime] = None From f2238e60855b212778792dc083be4ee8eb8dec5c Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 10 Jul 2024 12:13:40 +0200 Subject: [PATCH 17/34] chore: Apply format suggestions --- argilla/src/argilla/_models/_record/_record.py | 4 +--- argilla/src/argilla/records/_resource.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/argilla/src/argilla/_models/_record/_record.py b/argilla/src/argilla/_models/_record/_record.py index 05f7d0971c..09f2e42272 100644 --- a/argilla/src/argilla/_models/_record/_record.py +++ b/argilla/src/argilla/_models/_record/_record.py @@ -30,14 +30,12 @@ class RecordModel(ResourceModel): """Schema for the records of a `Dataset`""" + status: Literal["pending", "completed"] = "pending" fields: Optional[Dict[str, FieldValue]] = None metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict) vectors: Optional[List[VectorModel]] = Field(default_factory=list) responses: Optional[List[UserResponseModel]] = Field(default_factory=list) suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple) - - status: Literal["pending", "completed"] = "pending" - external_id: Optional[Any] = None @field_serializer("external_id", when_used="unless-none") diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index 0aa38b86be..27ec8be113 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -148,7 +148,7 @@ def vectors(self) -> "RecordVectors": return self.__vectors @property - def status(self) -> "str": + def status(self) -> str: return self._model.status @property From 2ea0a3e4e63b0f95d64ea452976d7a688177857a Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 10 Jul 2024 21:54:28 +0200 Subject: [PATCH 18/34] chore: Export distribution in dataset --- argilla/src/argilla/datasets/_resource.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/argilla/src/argilla/datasets/_resource.py b/argilla/src/argilla/datasets/_resource.py index 5237010f34..dc06c4a60e 100644 --- a/argilla/src/argilla/datasets/_resource.py +++ b/argilla/src/argilla/datasets/_resource.py @@ -24,6 +24,7 @@ from argilla.datasets._export import DiskImportExportMixin from argilla.records import DatasetRecords from argilla.settings import Settings +from argilla.settings._task_distribution import TaskDistribution from argilla.workspaces._resource import Workspace __all__ = ["Dataset"] @@ -133,6 +134,10 @@ def workspace(self) -> Workspace: self._workspace = 
self._resolve_workspace() return self._workspace + @property + def distribution(self) -> TaskDistribution: + return self.settings.distribution + ##################### # Core methods # ##################### @@ -205,7 +210,7 @@ def _resolve_workspace(self) -> Workspace: if not workspace.exists(): available_workspace_names = [ws.name for ws in self._client.workspaces] raise NotFoundError( - f"Workspace with name { workspace} not found. Available workspaces: {available_workspace_names}" + f"Workspace with name {workspace} not found. Available workspaces: {available_workspace_names}" ) elif isinstance(workspace, UUID): ws_model = self._client.api.workspaces.get(workspace) From bec0b0d0bf2c2d5e2f5f3de3d28df316b497e950 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 12 Jul 2024 11:13:36 +0200 Subject: [PATCH 19/34] feat: add session helper with serializable isolation level (#5165) # Description This PR add a new `get_serializable_async_db` function helper that returns a session using isolation leve as `SERIALIZABLE`. This session can be used on some handlers where we require that specific isolation level. As example I have added that session helper for handler deleting responses and PostgreSQL is showing the following received queries: ```sql 2024-07-04 17:09:40.417 CEST [83566] LOG: statement: BEGIN ISOLATION LEVEL READ COMMITTED; 2024-07-04 17:09:40.418 CEST [83566] LOG: execute __asyncpg_stmt_e__: SELECT users.first_name, users.last_name, users.username, users.role, users.api_key, users.password_hash, users.id, users.inserted_at, users.updated_at FROM users WHERE users.api_key = $1::VARCHAR 2024-07-04 17:09:40.418 CEST [83566] DETAIL: parameters: $1 = 'argilla.apikey' 2024-07-04 17:09:40.422 CEST [83566] LOG: execute __asyncpg_stmt_12__: SELECT users_1.id AS users_1_id, workspaces.name AS workspaces_name, workspaces.id AS workspaces_id, workspaces.inserted_at AS workspaces_inserted_at, workspaces.updated_at AS workspaces_updated_at FROM users AS users_1 JOIN workspaces_users AS workspaces_users_1 ON users_1.id = workspaces_users_1.user_id JOIN workspaces ON workspaces.id = workspaces_users_1.workspace_id WHERE users_1.id IN ($1::UUID) ORDER BY workspaces_users_1.inserted_at ASC 2024-07-04 17:09:40.422 CEST [83566] DETAIL: parameters: $1 = 'ed2d570f-cc9f-4d53-a433-74aa7a286a52' 2024-07-04 17:09:40.426 CEST [83566] LOG: execute __asyncpg_stmt_13__: SELECT users.first_name, users.last_name, users.username, users.role, users.api_key, users.password_hash, users.id, users.inserted_at, users.updated_at FROM users WHERE users.username = $1::VARCHAR 2024-07-04 17:09:40.426 CEST [83566] DETAIL: parameters: $1 = 'argilla' 2024-07-04 17:09:40.428 CEST [83566] LOG: execute __asyncpg_stmt_12__: SELECT users_1.id AS users_1_id, workspaces.name AS workspaces_name, workspaces.id AS workspaces_id, workspaces.inserted_at AS workspaces_inserted_at, workspaces.updated_at AS workspaces_updated_at FROM users AS users_1 JOIN workspaces_users AS workspaces_users_1 ON users_1.id = workspaces_users_1.user_id JOIN workspaces ON workspaces.id = workspaces_users_1.workspace_id WHERE users_1.id IN ($1::UUID) ORDER BY workspaces_users_1.inserted_at ASC 2024-07-04 17:09:40.428 CEST [83566] DETAIL: parameters: $1 = 'ed2d570f-cc9f-4d53-a433-74aa7a286a52' 2024-07-04 17:09:40.430 CEST [83563] LOG: statement: BEGIN ISOLATION LEVEL SERIALIZABLE; 2024-07-04 17:09:40.430 CEST [83563] LOG: execute __asyncpg_stmt_14__: SELECT responses.values, responses.status, responses.record_id, responses.user_id, 
responses.id, responses.inserted_at, responses.updated_at FROM responses WHERE responses.id = $1::UUID 2024-07-04 17:09:40.430 CEST [83563] DETAIL: parameters: $1 = 'fdea95a0-ee9a-43ea-b093-2e13f2473c19' 2024-07-04 17:09:40.431 CEST [83566] LOG: statement: ROLLBACK; 2024-07-04 17:09:40.432 CEST [83563] LOG: statement: ROLLBACK; ``` We can clearly see that there are two nested transaction: 1. The main one to get current user using default `get_async_db` helper. 2. A nested one using `get_serializable_async_db` (and setting `SERIALIZABLE` isolation level) trying to find the response by id. The response id used is fake so the transaction ends there and the deletion is not done. ## Missing things on this PR - [x] Fix some failing tests. - [ ] Tests are passing but still not changing the isolation level to `SERIALIZABLE`. - [ ] Check that this works as expected and does not affect SQLite. - [ ] Check that this works as expected with PostgreSQL (no concurrency errors). Closes #5155 **Type of change** - New feature (non-breaking change which adds functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** - [x] Manually seeing PostgreSQL logs. **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../argilla_server/api/handlers/v1/records.py | 4 +-- .../api/handlers/v1/responses.py | 6 ++--- .../src/argilla_server/contexts/datasets.py | 1 + argilla-server/src/argilla_server/database.py | 26 ++++++++++++++----- .../responses/upsert_responses_in_bulk.py | 6 +++-- argilla-server/tests/unit/conftest.py | 16 +++++++++--- 6 files changed, 42 insertions(+), 17 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/records.py b/argilla-server/src/argilla_server/api/handlers/v1/records.py index 3778921ee2..23398e93be 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/records.py @@ -26,7 +26,7 @@ from argilla_server.api.schemas.v1.suggestions import Suggestion as SuggestionSchema from argilla_server.api.schemas.v1.suggestions import SuggestionCreate, Suggestions from argilla_server.contexts import datasets -from argilla_server.database import get_async_db +from argilla_server.database import get_async_db, get_serializable_async_db from argilla_server.errors.future.base_errors import NotFoundError, UnprocessableEntityError from argilla_server.models import Dataset, Question, Record, Suggestion, User from argilla_server.search_engine import SearchEngine, get_search_engine @@ -88,7 +88,7 @@ async def update_record( @router.post("/records/{record_id}/responses", status_code=status.HTTP_201_CREATED, response_model=Response) async def create_record_response( *, - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_serializable_async_db), search_engine: SearchEngine = Depends(get_search_engine), record_id: UUID, response_create: ResponseCreate, diff --git a/argilla-server/src/argilla_server/api/handlers/v1/responses.py b/argilla-server/src/argilla_server/api/handlers/v1/responses.py index ddc389563a..95e468351f 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/responses.py +++ 
b/argilla-server/src/argilla_server/api/handlers/v1/responses.py @@ -28,7 +28,7 @@ ResponseUpdate, ) from argilla_server.contexts import datasets -from argilla_server.database import get_async_db +from argilla_server.database import get_serializable_async_db from argilla_server.models import Dataset, Record, Response, User from argilla_server.search_engine import SearchEngine, get_search_engine from argilla_server.security import auth @@ -55,7 +55,7 @@ async def create_current_user_responses_bulk( @router.put("/responses/{response_id}", response_model=ResponseSchema) async def update_response( *, - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_serializable_async_db), search_engine: SearchEngine = Depends(get_search_engine), response_id: UUID, response_update: ResponseUpdate, @@ -77,7 +77,7 @@ async def update_response( @router.delete("/responses/{response_id}", response_model=ResponseSchema) async def delete_response( *, - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_serializable_async_db), search_engine=Depends(get_search_engine), response_id: UUID, current_user: User = Security(auth.get_current_user), diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 700dfeaefa..4d5a5f89fe 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -967,6 +967,7 @@ async def create_response( ) await db.flush([response]) + await _load_users_from_responses([response]) await _touch_dataset_last_activity_at(db, record.dataset) await search_engine.update_record_response(response) await db.refresh(record, attribute_names=[Record.responses_submitted.key]) diff --git a/argilla-server/src/argilla_server/database.py b/argilla-server/src/argilla_server/database.py index e0bc4c4c95..eaaf27079b 100644 --- a/argilla-server/src/argilla_server/database.py +++ b/argilla-server/src/argilla_server/database.py @@ -14,19 +14,17 @@ import os from collections import OrderedDict from sqlite3 import Connection as SQLite3Connection -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING, AsyncGenerator, Optional from sqlalchemy import event, make_url from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.engine.interfaces import IsolationLevel +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine, AsyncSession from sqlalchemy.dialects.sqlite.aiosqlite import AsyncAdapt_aiosqlite_connection import argilla_server from argilla_server.settings import settings -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncSession - ALEMBIC_CONFIG_FILE = os.path.normpath(os.path.join(os.path.dirname(argilla_server.__file__), "alembic.ini")) TAGGED_REVISIONS = OrderedDict( @@ -55,9 +53,23 @@ def set_sqlite_pragma(dbapi_connection, connection_record): AsyncSessionLocal = async_sessionmaker(autocommit=False, expire_on_commit=False, bind=async_engine) -async def get_async_db() -> Generator["AsyncSession", None, None]: +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + async for db in _get_async_db(): + yield db + + +async def get_serializable_async_db() -> AsyncGenerator[AsyncSession, None]: + async for db in _get_async_db(isolation_level="SERIALIZABLE"): + yield db + + +async def _get_async_db(isolation_level: Optional[IsolationLevel] = None) -> AsyncGenerator[AsyncSession, None]: + db: AsyncSession 
= AsyncSessionLocal() + + if isolation_level is not None: + await db.connection(execution_options={"isolation_level": isolation_level}) + try: - db: "AsyncSession" = AsyncSessionLocal() yield db finally: await db.close() diff --git a/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py b/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py index 547dd7e68b..520194e46a 100644 --- a/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py +++ b/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py @@ -20,7 +20,7 @@ from argilla_server.api.policies.v1 import RecordPolicy, authorize from argilla_server.api.schemas.v1.responses import Response, ResponseBulk, ResponseBulkError, ResponseUpsert from argilla_server.contexts import datasets -from argilla_server.database import get_async_db +from argilla_server.database import get_serializable_async_db from argilla_server.errors import future as errors from argilla_server.models import User from argilla_server.search_engine import SearchEngine, get_search_engine @@ -55,6 +55,8 @@ async def execute(self, responses: List[ResponseUpsert], user: User) -> List[Res class UpsertResponsesInBulkUseCaseFactory: def __call__( - self, db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine) + self, + db: AsyncSession = Depends(get_serializable_async_db), + search_engine: SearchEngine = Depends(get_search_engine), ): return UpsertResponsesInBulkUseCase(db, search_engine) diff --git a/argilla-server/tests/unit/conftest.py b/argilla-server/tests/unit/conftest.py index fe3479ea6d..a1dac6fbc5 100644 --- a/argilla-server/tests/unit/conftest.py +++ b/argilla-server/tests/unit/conftest.py @@ -14,14 +14,15 @@ import contextlib import uuid -from typing import TYPE_CHECKING, Dict, Generator +from typing import TYPE_CHECKING, Dict, Generator, Optional import pytest import pytest_asyncio +from sqlalchemy.engine.interfaces import IsolationLevel from argilla_server import telemetry from argilla_server.api.routes import api_v1 from argilla_server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY -from argilla_server.database import get_async_db +from argilla_server.database import get_async_db, get_serializable_async_db from argilla_server.models import User, UserRole, Workspace from argilla_server.search_engine import SearchEngine, get_search_engine from argilla_server.settings import settings @@ -78,10 +79,18 @@ async def async_client( ) -> Generator["AsyncClient", None, None]: from argilla_server import app - async def override_get_async_db(): + async def override_get_async_db(isolation_level: Optional[IsolationLevel] = None): session = TestSession() + + if isolation_level is not None: + await session.connection(execution_options={"isolation_level": isolation_level}) + yield session + async def override_get_serializable_async_db(): + async for session in override_get_async_db(isolation_level="SERIALIZABLE"): + yield session + async def override_get_search_engine(): yield mock_search_engine @@ -89,6 +98,7 @@ async def override_get_search_engine(): for api in [api_v1]: api.dependency_overrides[get_async_db] = override_get_async_db + api.dependency_overrides[get_serializable_async_db] = override_get_serializable_async_db api.dependency_overrides[get_search_engine] = override_get_search_engine async with AsyncClient(app=app, base_url="http://testserver") as async_client: From 85e847f68b1a5612d283ba39375d0631474ae639 Mon Sep 
17 00:00:00 2001 From: Paco Aranda Date: Fri, 12 Jul 2024 11:46:25 +0200 Subject: [PATCH 20/34] [REFACTOR] `argilla-server`: remove deprecated records endpoint (#5206) # Description This PR removes deprecated endpoints working with records to avoid creating records with a proper record status computation. The affected endpoints are: `POST /api/v1/datasets/:dataset_id/records` `PATCH /api/v1/datasets/:dataset_id/records` **Type of change** - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/CHANGELOG.md | 9 +- .../api/handlers/v1/datasets/records.py | 67 ----- .../src/argilla_server/contexts/datasets.py | 155 +---------- .../test_create_dataset_records_in_bulk.py} | 6 +- .../test_update_dataset_records_in_bulk.py} | 8 +- .../unit/api/handlers/v1/test_datasets.py | 262 ++++++------------ 6 files changed, 111 insertions(+), 396 deletions(-) rename argilla-server/tests/unit/api/handlers/v1/datasets/records/{test_create_dataset_records.py => records_bulk/test_create_dataset_records_in_bulk.py} (98%) rename argilla-server/tests/unit/api/handlers/v1/datasets/records/{test_update_dataset_records.py => records_bulk/test_update_dataset_records_in_bulk.py} (97%) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 0338a4e503..64651f3ad3 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -24,13 +24,18 @@ These are the section headers that we use: ### Changed - Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) -- [breaking] Change `GET /datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) -- [breaking] Change `GET /me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) +- [breaking] Change `GET /api/v1/datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) +- [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) ### Fixed - Fixed SQLite connection settings not working correctly due to a outdated conditional. ([#5149](https://github.com/argilla-io/argilla/pull/5149)) +### Removed + +- [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) +- [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. 
([#5206](https://github.com/argilla-io/argilla/pull/5206)) + ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) ### Changed diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index e032aa7037..8cc5ee2538 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -34,8 +34,6 @@ RecordFilterScope, RecordIncludeParam, Records, - RecordsCreate, - RecordsUpdate, SearchRecord, SearchRecordsQuery, SearchRecordsResult, @@ -424,71 +422,6 @@ async def list_dataset_records( return Records(items=records, total=total) -@router.post( - "/datasets/{dataset_id}/records", - status_code=status.HTTP_204_NO_CONTENT, - deprecated=True, - description="Deprecated in favor of POST /datasets/{dataset_id}/records/bulk", -) -async def create_dataset_records( - *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), - dataset_id: UUID, - records_create: RecordsCreate, - current_user: User = Security(auth.get_current_user), -): - dataset = await Dataset.get_or_raise( - db, - dataset_id, - options=[ - selectinload(Dataset.fields), - selectinload(Dataset.questions), - selectinload(Dataset.metadata_properties), - selectinload(Dataset.vectors_settings), - ], - ) - - await authorize(current_user, DatasetPolicy.create_records(dataset)) - - await datasets.create_records(db, search_engine, dataset, records_create) - - telemetry_client.track_data(action="DatasetRecordsCreated", data={"records": len(records_create.items)}) - - -@router.patch( - "/datasets/{dataset_id}/records", - status_code=status.HTTP_204_NO_CONTENT, - deprecated=True, - description="Deprecated in favor of PUT /datasets/{dataset_id}/records/bulk", -) -async def update_dataset_records( - *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), - dataset_id: UUID, - records_update: RecordsUpdate, - current_user: User = Security(auth.get_current_user), -): - dataset = await Dataset.get_or_raise( - db, - dataset_id, - options=[ - selectinload(Dataset.fields), - selectinload(Dataset.questions), - selectinload(Dataset.metadata_properties), - ], - ) - - await authorize(current_user, DatasetPolicy.update_records(dataset)) - - await datasets.update_records(db, search_engine, dataset, records_update) - - telemetry_client.track_data(action="DatasetRecordsUpdated", data={"records": len(records_update.items)}) - - @router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT) async def delete_dataset_records( *, diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 4d5a5f89fe..b95fcde5e1 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -33,7 +33,7 @@ import sqlalchemy from fastapi.encoders import jsonable_encoder -from sqlalchemy import Select, and_, case, func, select +from sqlalchemy import Select, and_, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload @@ -42,8 +42,6 @@ from argilla_server.api.schemas.v1.records import ( RecordCreate, 
RecordIncludeParam, - RecordsCreate, - RecordsUpdate, RecordUpdateWithId, ) from argilla_server.api.schemas.v1.responses import ( @@ -60,7 +58,7 @@ ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema from argilla_server.contexts import accounts, distribution -from argilla_server.enums import DatasetStatus, RecordInclude, UserRole, RecordStatus +from argilla_server.enums import DatasetStatus, UserRole, RecordStatus from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( Dataset, @@ -74,7 +72,6 @@ User, Vector, VectorSettings, - Workspace, ) from argilla_server.models.suggestions import SuggestionCreateWithRecordId from argilla_server.search_engine import SearchEngine @@ -87,9 +84,6 @@ from argilla_server.validators.suggestions import SuggestionCreateValidator if TYPE_CHECKING: - from argilla_server.api.schemas.v1.datasets import ( - DatasetUpdate, - ) from argilla_server.api.schemas.v1.fields import FieldUpdate from argilla_server.api.schemas.v1.records import RecordUpdate from argilla_server.api.schemas.v1.suggestions import SuggestionCreate @@ -231,7 +225,8 @@ async def create_metadata_property( ) -> MetadataProperty: if await MetadataProperty.get_by(db, name=metadata_property_create.name, dataset_id=dataset.id): raise NotUniqueError( - f"Metadata property with name `{metadata_property_create.name}` already exists for dataset with id `{dataset.id}`" + f"Metadata property with name `{metadata_property_create.name}` already exists " + f"for dataset with id `{dataset.id}`" ) async with db.begin_nested(): @@ -292,7 +287,8 @@ async def create_vector_settings( if await VectorSettings.get_by(db, name=vector_settings_create.name, dataset_id=dataset.id): raise NotUniqueError( - f"Vector settings with name `{vector_settings_create.name}` already exists for dataset with id `{dataset.id}`" + f"Vector settings with name `{vector_settings_create.name}` already exists " + f"for dataset with id `{dataset.id}`" ) async with db.begin_nested(): @@ -403,7 +399,7 @@ async def get_user_dataset_metrics(db: AsyncSession, user_id: UUID, dataset_id: .filter( Record.dataset_id == dataset_id, Record.status == RecordStatus.pending, - Response.id == None, + Response.id == None, # noqa ), ), ) @@ -549,57 +545,6 @@ async def _build_record( ) -async def create_records( - db: AsyncSession, search_engine: SearchEngine, dataset: Dataset, records_create: RecordsCreate -): - if not dataset.is_ready: - raise UnprocessableEntityError("Records cannot be created for a non published dataset") - - records = [] - - caches = { - "users_ids_cache": set(), - "questions_cache": {}, - "metadata_properties_cache": {}, - "vectors_settings_cache": {}, - } - - for record_i, record_create in enumerate(records_create.items): - try: - record = await _build_record(db, dataset, record_create, caches) - - record.responses = await _build_record_responses( - db, record, record_create.responses, caches["users_ids_cache"] - ) - - record.suggestions = await _build_record_suggestions( - db, record, record_create.suggestions, caches["questions_cache"] - ) - - record.vectors = await _build_record_vectors( - db, - dataset, - record_create.vectors, - build_vector_func=lambda value, vector_settings_id: Vector( - value=value, vector_settings_id=vector_settings_id - ), - cache=caches["vectors_settings_cache"], - ) - - except (UnprocessableEntityError, ValueError) as e: - raise UnprocessableEntityError(f"Record at position {record_i} is not valid because {e}") from e - - 
records.append(record) - - async with db.begin_nested(): - db.add_all(records) - await db.flush(records) - await _preload_records_relationships_before_index(db, records) - await search_engine.index_records(dataset, records) - - await db.commit() - - async def _load_users_from_responses(responses: Union[Response, Iterable[Response]]) -> None: if isinstance(responses, Response): responses = [responses] @@ -808,92 +753,6 @@ async def preload_records_relationships_before_validate(db: AsyncSession, record ) -async def update_records( - db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_update: "RecordsUpdate" -) -> None: - records_ids = [record_update.id for record_update in records_update.items] - - if len(records_ids) != len(set(records_ids)): - raise UnprocessableEntityError("Found duplicate records IDs") - - existing_records_ids = await _exists_records_with_ids(db, dataset_id=dataset.id, records_ids=records_ids) - non_existing_records_ids = set(records_ids) - set(existing_records_ids) - - if len(non_existing_records_ids) > 0: - sorted_non_existing_records_ids = sorted(non_existing_records_ids, key=lambda x: records_ids.index(x)) - records_str = ", ".join([str(record_id) for record_id in sorted_non_existing_records_ids]) - raise UnprocessableEntityError(f"Found records that do not exist: {records_str}") - - # Lists to store the records that will be updated in the database or in the search engine - records_update_objects: List[Dict[str, Any]] = [] - records_search_engine_update: List[UUID] = [] - records_delete_suggestions: List[UUID] = [] - - # Cache dictionaries to avoid querying the database multiple times - caches = { - "metadata_properties": {}, - "questions": {}, - "vector_settings": {}, - } - - existing_records = await get_records_by_ids(db, records_ids=records_ids, dataset_id=dataset.id) - - suggestions = [] - upsert_vectors = [] - for record_i, (record_update, record) in enumerate(zip(records_update.items, existing_records)): - try: - params, record_suggestions, record_vectors, needs_search_engine_update, caches = await _build_record_update( - db, record, record_update, caches - ) - - if record_suggestions is not None: - suggestions.extend(record_suggestions) - records_delete_suggestions.append(record_update.id) - - upsert_vectors.extend(record_vectors) - - if needs_search_engine_update: - records_search_engine_update.append(record_update.id) - - # Only update the record if there are params to update - if len(params) > 1: - records_update_objects.append(params) - except (UnprocessableEntityError, ValueError) as e: - raise UnprocessableEntityError(f"Record at position {record_i} is not valid because {e}") from e - - async with db.begin_nested(): - if records_delete_suggestions: - params = [Suggestion.record_id.in_(records_delete_suggestions)] - await Suggestion.delete_many(db, params=params, autocommit=False) - - if suggestions: - db.add_all(suggestions) - - if upsert_vectors: - await Vector.upsert_many( - db, - objects=upsert_vectors, - constraints=[Vector.record_id, Vector.vector_settings_id], - autocommit=False, - ) - - if records_update_objects: - await Record.update_many(db, records_update_objects, autocommit=False) - - if records_search_engine_update: - records = await get_records_by_ids( - db, - dataset_id=dataset.id, - records_ids=records_search_engine_update, - include=RecordIncludeParam(keys=[RecordInclude.vectors], vectors=None), - ) - await dataset.awaitable_attrs.vectors_settings - await _preload_records_relationships_before_index(db, records) - 
await search_engine.index_records(dataset, records) - - await db.commit() - - async def delete_records( db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_ids: List[UUID] ) -> None: diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_create_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py similarity index 98% rename from argilla-server/tests/unit/api/handlers/v1/datasets/records/test_create_dataset_records.py rename to argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py index c13bc9b6cb..7110e9ce62 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_create_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py @@ -34,9 +34,9 @@ @pytest.mark.asyncio -class TestCreateDatasetRecords: +class TestCreateDatasetRecordsInBulk: def url(self, dataset_id: UUID) -> str: - return f"/api/v1/datasets/{dataset_id}/records" + return f"/api/v1/datasets/{dataset_id}/records/bulk" async def test_create_dataset_records( self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict @@ -209,7 +209,7 @@ async def test_create_dataset_records( }, ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1 assert (await db.execute(select(func.count(Response.id)))).scalar_one() == 1 diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_update_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_update_dataset_records_in_bulk.py similarity index 97% rename from argilla-server/tests/unit/api/handlers/v1/datasets/records/test_update_dataset_records.py rename to argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_update_dataset_records_in_bulk.py index cf9fa909e9..ffb7a24dc7 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_update_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_update_dataset_records_in_bulk.py @@ -35,9 +35,9 @@ @pytest.mark.asyncio -class TestUpdateDatasetRecords: +class TestUpdateDatasetRecordsInBulk: def url(self, dataset_id: UUID) -> str: - return f"/api/v1/datasets/{dataset_id}/records" + return f"/api/v1/datasets/{dataset_id}/records/bulk" async def test_update_dataset_records( self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict @@ -121,7 +121,7 @@ async def test_update_dataset_records( dataset=dataset, ) - response = await async_client.patch( + response = await async_client.put( self.url(dataset.id), headers=owner_auth_header, json={ @@ -180,7 +180,7 @@ async def test_update_dataset_records( }, ) - assert response.status_code == 204 + assert response.status_code == 200 assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1 assert (await db.execute(select(func.count(Suggestion.id)))).scalar_one() == 6 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 9404b3850e..557cb4de70 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -1807,10 +1807,10 @@ async def test_create_dataset_records( } response 
= await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() assert (await db.execute(select(func.count(Record.id)))).scalar() == 5 assert (await db.execute(select(func.count(Response.id)))).scalar() == 4 assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 3 @@ -1872,13 +1872,13 @@ async def test_create_dataset_records_with_response_for_multiple_users( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) await db.refresh(annotator) await db.refresh(owner) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() assert (await db.execute(select(func.count(Record.id)))).scalar() == 2 assert (await db.execute(select(func.count(Response.id)).where(Response.user_id == annotator.id))).scalar() == 2 assert (await db.execute(select(func.count(Response.id)).where(Response.user_id == owner.id))).scalar() == 1 @@ -1912,7 +1912,7 @@ async def test_create_dataset_records_with_response_for_unknown_user( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422, response.json() @@ -1950,7 +1950,7 @@ async def test_create_dataset_records_with_duplicated_response_for_an_user( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422, response.json() @@ -1986,7 +1986,7 @@ async def test_create_dataset_records_with_not_valid_suggestion( question = await TextFieldFactory.create(name="input", dataset=dataset) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={"question_id": str(question.id), **payload}, ) @@ -2020,13 +2020,10 @@ async def test_create_dataset_records_with_missing_required_fields( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because missing required value for field: 'output'" - } assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 async def test_create_dataset_records_with_wrong_value_field( @@ -2054,7 +2051,7 @@ async def test_create_dataset_records_with_wrong_value_field( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2092,13 +2089,10 @@ async def test_create_dataset_records_with_extra_fields( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", 
headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because found fields values for non configured fields: ['output']" - } assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 @pytest.mark.parametrize( @@ -2120,10 +2114,10 @@ async def test_create_dataset_records_with_optional_fields( records_json = {"items": [record_json]} response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() await db.refresh(dataset, attribute_names=["records"]) assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 @@ -2145,7 +2139,7 @@ async def test_create_dataset_records_with_wrong_optional_fields( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 assert response.json() == { @@ -2203,10 +2197,10 @@ async def test_create_dataset_records_with_metadata_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 record = (await db.execute(select(Record))).scalar() assert record.metadata_ == {"metadata-property": value} @@ -2242,7 +2236,7 @@ async def test_create_dataset_records_with_metadata_nan_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2280,14 +2274,10 @@ async def test_create_dataset_records_with_not_valid_metadata_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert ( - "Record at position 0 is not valid because metadata is not valid: 'metadata-property' metadata property validation failed" - in response.json()["detail"] - ) async def test_create_dataset_records_with_extra_metadata_allowed( self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict @@ -2307,10 +2297,10 @@ async def test_create_dataset_records_with_extra_metadata_allowed( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 record = (await db.execute(select(Record))).scalar() assert record.metadata_ == {"terms-metadata": "a", "extra": {"this": {"is": "extra metadata"}}} @@ -2332,15 +2322,10 @@ async def test_create_dataset_records_with_extra_metadata_not_allowed( } response = await async_client.post( - 
f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert ( - "Record at position 0 is not valid because metadata is not valid: 'not-defined-metadata-property' metadata" - f" property does not exists for dataset '{dataset.id}' and extra metadata is not allowed for this dataset" - == response.json()["detail"] - ) @pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin]) async def test_create_dataset_records_with_vectors( @@ -2356,7 +2341,7 @@ async def test_create_dataset_records_with_vectors( vector_settings_b = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -2376,7 +2361,7 @@ async def test_create_dataset_records_with_vectors( }, ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Vector.id)))).scalar() == 3 vector_a, vector_b, vector_c = (await db.execute(select(Vector))).scalars().all() @@ -2407,7 +2392,7 @@ async def test_create_dataset_records_with_invalid_vector( vector_settings = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -2420,10 +2405,6 @@ async def test_create_dataset_records_with_invalid_vector( ) assert response.status_code == 422 - assert response.json()["detail"] == ( - f"Record at position 0 is not valid because vector with name={vector_settings.name} is not valid: " - f"vector must have {vector_settings.dimensions} elements, got 1 elements" - ) async def test_create_dataset_records_with_non_existent_vector_settings( self, async_client: "AsyncClient", owner_auth_header: dict @@ -2433,7 +2414,7 @@ async def test_create_dataset_records_with_non_existent_vector_settings( await TextQuestionFactory.create(name="text_ok", dataset=dataset) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -2446,10 +2427,6 @@ async def test_create_dataset_records_with_non_existent_vector_settings( ) assert response.status_code == 422 - assert response.json()["detail"] == ( - "Record at position 0 is not valid because vector with name=missing_vector is not valid: " - f"vector with name=missing_vector does not exist for dataset_id={str(dataset.id)}" - ) async def test_create_dataset_records_with_vector_settings_id_from_another_dataset( self, async_client: "AsyncClient", owner_auth_header: dict @@ -2462,7 +2439,7 @@ async def test_create_dataset_records_with_vector_settings_id_from_another_datas vector_settings = await VectorSettingsFactory.create(dimensions=5) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -2475,10 +2452,6 @@ async def test_create_dataset_records_with_vector_settings_id_from_another_datas ) assert response.status_code == 422 - assert response.json()["detail"] == ( - f"Record at position 0 is not valid because vector with name={vector_settings.name} is not valid: " - f"vector with 
name={vector_settings.name} does not exist for dataset_id={dataset.id}" - ) async def test_create_dataset_records_with_index_error( self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", owner_auth_header: dict @@ -2494,7 +2467,7 @@ async def test_create_dataset_records_with_index_error( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2517,7 +2490,7 @@ async def test_create_dataset_records_without_authentication(self, async_client: ], } - response = await async_client.post(f"/api/v1/datasets/{dataset.id}/records", json=records_json) + response = await async_client.post(f"/api/v1/datasets/{dataset.id}/records/bulk", json=records_json) assert response.status_code == 401 assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 @@ -2589,10 +2562,12 @@ async def test_create_dataset_records_as_admin( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: admin.api_key}, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", + headers={API_KEY_HEADER_NAME: admin.api_key}, + json=records_json, ) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() assert (await db.execute(select(func.count(Record.id)))).scalar() == 5 assert (await db.execute(select(func.count(Response.id)))).scalar() == 4 @@ -2623,7 +2598,7 @@ async def test_create_dataset_records_as_annotator(self, async_client: "AsyncCli } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: annotator.api_key}, json=records_json, ) @@ -2652,7 +2627,9 @@ async def test_create_dataset_records_as_admin_from_another_workspace(self, asyn } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: admin.api_key}, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", + headers={API_KEY_HEADER_NAME: admin.api_key}, + json=records_json, ) assert response.status_code == 403 @@ -2683,10 +2660,10 @@ async def test_create_dataset_records_with_submitted_response( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 assert (await db.execute(select(func.count(Response.id)))).scalar() == 1 @@ -2714,7 +2691,7 @@ async def test_create_dataset_records_with_submitted_response_without_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2751,10 +2728,10 @@ async def test_create_dataset_records_with_discarded_response( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert 
response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 assert ( await db.execute(select(func.count(Response.id)).filter(Response.status == ResponseStatus.discarded)) @@ -2790,10 +2767,10 @@ async def test_create_dataset_records_with_draft_response( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 assert ( await db.execute(select(func.count(Response.id)).filter(Response.status == ResponseStatus.draft)) @@ -2823,7 +2800,7 @@ async def test_create_dataset_records_with_invalid_response_status( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2859,10 +2836,10 @@ async def test_create_dataset_records_with_discarded_response_without_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Response.id)))).scalar() == 1 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 @@ -2877,11 +2854,10 @@ async def test_create_dataset_records_with_non_published_dataset( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert response.json() == {"detail": "Records cannot be created for a non published dataset"} assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 assert (await db.execute(select(func.count(Response.id)))).scalar() == 0 @@ -2900,7 +2876,7 @@ async def test_create_dataset_records_with_less_items_than_allowed( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2922,7 +2898,7 @@ async def test_create_dataset_records_with_more_items_than_allowed( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2942,7 +2918,7 @@ async def test_create_dataset_records_with_invalid_records( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2957,7 +2933,7 @@ async def test_create_dataset_records_with_nonexistent_dataset_id( await DatasetFactory.create() response = await async_client.post( - f"/api/v1/datasets/{dataset_id}/records", + f"/api/v1/datasets/{dataset_id}/records/bulk", 
headers=owner_auth_header, json={ "items": [ @@ -2977,7 +2953,7 @@ async def test_create_dataset_records_with_nonexistent_dataset_id( async def test_update_dataset_records( self, async_client: "AsyncClient", mock_search_engine: "SearchEngine", role: UserRole ): - dataset = await DatasetFactory.create() + dataset = await DatasetFactory.create(status=DatasetStatus.ready) user = await UserFactory.create(workspaces=[dataset.workspace], role=role) await TermsMetadataPropertyFactory.create(name="terms-metadata-property", dataset=dataset) await IntegerMetadataPropertyFactory.create(name="integer-metadata-property", dataset=dataset) @@ -2988,8 +2964,8 @@ async def test_update_dataset_records( metadata_={"terms-metadata-property": "z", "integer-metadata-property": 1, "float-metadata-property": 1.0}, ) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -3027,7 +3003,7 @@ async def test_update_dataset_records( }, ) - assert response.status_code == 204 + assert response.status_code == 200, response.json() # Record 0 assert records[0].metadata_ == { @@ -3060,13 +3036,12 @@ async def test_update_dataset_records( "float-metadata-property": 1.0, } - # it should be called only with the first three records (metadata was updated for them) - mock_search_engine.index_records.assert_called_once_with(dataset, records[:3]) + mock_search_engine.index_records.assert_called_once_with(dataset, records[:4]) async def test_update_dataset_records_with_suggestions( self, async_client: "AsyncClient", mock_search_engine: "SearchEngine", owner_auth_header: dict ): - dataset = await DatasetFactory.create() + dataset = await DatasetFactory.create(status=DatasetStatus.ready) question_0 = await TextQuestionFactory.create(dataset=dataset) question_1 = await TextQuestionFactory.create(dataset=dataset) question_2 = await TextQuestionFactory.create(dataset=dataset) @@ -3093,8 +3068,8 @@ async def test_update_dataset_records_with_suggestions( await SuggestionFactory.create(question=question_2, record=records[2], value="suggestion 2 3"), ] - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3142,21 +3117,17 @@ async def test_update_dataset_records_with_suggestions( }, ) - assert response.status_code == 204 + assert response.status_code == 200 # Record 0 await records[0].awaitable_attrs.suggestions assert records[0].suggestions[0].value == "suggestion updated 0 1" assert records[0].suggestions[1].value == "suggestion updated 0 2" assert records[0].suggestions[2].value == "suggestion updated 0 3" - for suggestion in suggestions_records_0: - assert inspect(suggestion).deleted # Record 1 await records[1].awaitable_attrs.suggestions assert records[1].suggestions[0].value == "suggestion updated 1 1" - for suggestion in suggestions_records_1: - assert inspect(suggestion).deleted # Record 2 for suggestion in suggestions_records_2: @@ -3168,39 +3139,12 @@ async def test_update_dataset_records_with_suggestions( assert records[3].suggestions[1].value == "suggestion updated 3 2" assert records[3].suggestions[2].value == "suggestion updated 3 3" - mock_search_engine.index_records.assert_not_called() - - async def test_update_dataset_records_with_empty_list_of_suggestions( - self, async_client: 
"AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - question_0 = await TextQuestionFactory.create(dataset=dataset) - question_1 = await TextQuestionFactory.create(dataset=dataset) - question_2 = await TextQuestionFactory.create(dataset=dataset) - record = await RecordFactory.create(dataset=dataset) - - suggestions_records_0 = [ - await SuggestionFactory.create(question=question_0, record=record, value="suggestion 0 1"), - await SuggestionFactory.create(question=question_1, record=record, value="suggestion 0 2"), - await SuggestionFactory.create(question=question_2, record=record, value="suggestion 0 3"), - ] - - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", - headers=owner_auth_header, - json={"items": [{"id": str(record.id), "suggestions": []}]}, - ) - - assert response.status_code == 204 - - assert await record.awaitable_attrs.suggestions == [] - for suggestion in suggestions_records_0: - assert inspect(suggestion).deleted + mock_search_engine.index_records.assert_called_once() async def test_update_dataset_records_with_vectors( self, async_client: "AsyncClient", mock_search_engine: "SearchEngine", owner_auth_header: dict ): - dataset = await DatasetFactory.create() + dataset = await DatasetFactory.create(status=DatasetStatus.ready) vector_settings_0 = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) vector_settings_1 = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) vector_settings_2 = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) @@ -3216,8 +3160,8 @@ async def test_update_dataset_records_with_vectors( await VectorFactory.create(vector_settings=vector_settings_1, record=records[1], value=[4, 4, 4, 4, 4]) await VectorFactory.create(vector_settings=vector_settings_2, record=records[1], value=[5, 5, 5, 5, 5]) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3247,7 +3191,7 @@ async def test_update_dataset_records_with_vectors( }, ) - assert response.status_code == 204 + assert response.status_code == 200 # Record 0 await records[0].awaitable_attrs.vectors @@ -3276,8 +3220,8 @@ async def test_update_dataset_records_with_invalid_metadata( await TermsMetadataPropertyFactory.create(dataset=dataset, name="terms") records = await RecordFactory.create_batch(5, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3298,10 +3242,6 @@ async def test_update_dataset_records_with_invalid_metadata( ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 1 is not valid because metadata is not valid: 'terms' metadata property " - "validation failed because 'i was not declared' is not an allowed term." 
- } async def test_update_dataset_records_with_metadata_nan_value( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3311,8 +3251,8 @@ async def test_update_dataset_records_with_metadata_nan_value( await FloatMetadataPropertyFactory.create(dataset=dataset, name="float") records = await RecordFactory.create_batch(3, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3341,8 +3281,8 @@ async def test_update_dataset_records_with_invalid_suggestions( question = await LabelSelectionQuestionFactory.create(dataset=dataset) records = await RecordFactory.create_batch(5, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3356,9 +3296,6 @@ async def test_update_dataset_records_with_invalid_suggestions( ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Record at position 0 is not valid because suggestion for question_id={question.id} is not valid: 'option-a' is not a valid label for label selection question.\nValid labels are: ['option1', 'option2', 'option3']" - } async def test_update_dataset_records_with_invalid_vectors( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3367,8 +3304,8 @@ async def test_update_dataset_records_with_invalid_vectors( vector_settings = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) records = await RecordFactory.create_batch(5, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3378,18 +3315,14 @@ async def test_update_dataset_records_with_invalid_vectors( ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Record at position 0 is not valid because vector with name={vector_settings.name} is not " - "valid: vector must have 5 elements, got 6 elements" - } async def test_update_dataset_records_with_nonexistent_dataset_id( self, async_client: "AsyncClient", owner_auth_header: dict ): dataset_id = uuid4() - response = await async_client.patch( - f"/api/v1/datasets/{dataset_id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset_id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3413,16 +3346,13 @@ async def test_update_dataset_records_with_nonexistent_records( records.append({"id": str(record.id), "metadata": {"i exists": True}}) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={"items": records}, ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Found records that do not exist: {records[0]['id']}, {records[1]['id']}, {records[2]['id']}" - } async def test_update_dataset_records_with_nonexistent_question_id( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3432,8 +3362,8 @@ async def test_update_dataset_records_with_nonexistent_question_id( question_id = str(uuid4()) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + 
f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3443,10 +3373,6 @@ async def test_update_dataset_records_with_nonexistent_question_id( ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Record at position 0 is not valid because suggestion for question_id={question_id} is not " - f"valid: question_id={question_id} does not exist" - } async def test_update_dataset_records_with_nonexistent_vector_settings_name( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3454,17 +3380,13 @@ async def test_update_dataset_records_with_nonexistent_vector_settings_name( dataset = await DatasetFactory.create() record = await RecordFactory.create(dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={"items": [{"id": str(record.id), "vectors": {"i-do-not-exist": [1, 2, 3, 4]}}]}, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because vector with name=i-do-not-exist is not valid: vector " - f"with name=i-do-not-exist does not exist for dataset_id={dataset.id}" - } async def test_update_dataset_records_with_duplicate_records_ids( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3472,8 +3394,8 @@ async def test_update_dataset_records_with_duplicate_records_ids( dataset = await DatasetFactory.create() record = await RecordFactory.create(dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3484,7 +3406,6 @@ async def test_update_dataset_records_with_duplicate_records_ids( ) assert response.status_code == 422 - assert response.json() == {"detail": "Found duplicate records IDs"} async def test_update_dataset_records_with_duplicate_suggestions_question_ids( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3493,8 +3414,8 @@ async def test_update_dataset_records_with_duplicate_suggestions_question_ids( question = await TextQuestionFactory.create(dataset=dataset) record = await RecordFactory.create(dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3510,16 +3431,13 @@ async def test_update_dataset_records_with_duplicate_suggestions_question_ids( ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because found duplicate suggestions question IDs" - } async def test_update_dataset_records_as_admin_from_another_workspace(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() user = await UserFactory.create(role=UserRole.admin) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -3536,8 +3454,8 @@ async def test_update_dataset_records_as_annotator(self, async_client: "AsyncCli dataset = await DatasetFactory.create() user = await UserFactory.create(role=UserRole.annotator, workspaces=[dataset.workspace]) - response = await async_client.patch( - 
f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -3553,8 +3471,8 @@ async def test_update_dataset_records_as_annotator(self, async_client: "AsyncCli async def test_update_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", json={"items": [{"id": str(uuid4())}]} + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", json={"items": [{"id": str(uuid4())}]} ) assert response.status_code == 401 From 104148725082ccc53af923863d4a83288d72aaee Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 12 Jul 2024 12:09:29 +0200 Subject: [PATCH 21/34] chore: Add task distribution setter for dataset --- argilla/src/argilla/datasets/_resource.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/argilla/src/argilla/datasets/_resource.py b/argilla/src/argilla/datasets/_resource.py index dc06c4a60e..6df656f737 100644 --- a/argilla/src/argilla/datasets/_resource.py +++ b/argilla/src/argilla/datasets/_resource.py @@ -138,6 +138,10 @@ def workspace(self) -> Workspace: def distribution(self) -> TaskDistribution: return self.settings.distribution + @distribution.setter + def distribution(self, value: TaskDistribution) -> None: + self.settings.distribution = value + ##################### # Core methods # ##################### From c219764e41600c1b9a14f80a1a27aa0e43e129cc Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Fri, 12 Jul 2024 16:56:26 +0200 Subject: [PATCH 22/34] [ENHANCEMENT] `argilla`: add record `status` property (#5184) # Description This PR adds the record status as a read-only property in the `Record` resource class. Closes https://github.com/argilla-io/argilla/issues/5141 **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../src/argilla/_models/_record/_record.py | 4 ++-- argilla/src/argilla/records/_resource.py | 22 ++++++++++++++----- .../tests/integration/test_list_records.py | 13 +++++++++++ argilla/tests/unit/test_io/test_generic.py | 1 + .../tests/unit/test_io/test_hf_datasets.py | 1 + .../tests/unit/test_resources/test_records.py | 10 +++++++++ 6 files changed, 44 insertions(+), 7 deletions(-) diff --git a/argilla/src/argilla/_models/_record/_record.py b/argilla/src/argilla/_models/_record/_record.py index 38a4996c96..09f2e42272 100644 --- a/argilla/src/argilla/_models/_record/_record.py +++ b/argilla/src/argilla/_models/_record/_record.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Literal from pydantic import Field, field_serializer, field_validator @@ -30,12 +30,12 @@ class RecordModel(ResourceModel): """Schema for the records of a `Dataset`""" + status: Literal["pending", "completed"] = "pending" fields: Optional[Dict[str, FieldValue]] = None metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict) vectors: Optional[List[VectorModel]] = Field(default_factory=list) responses: Optional[List[UserResponseModel]] = Field(default_factory=list) suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple) - external_id: Optional[Any] = None @field_serializer("external_id", when_used="unless-none") diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index 53c1321b4b..27ec8be113 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -103,7 +103,7 @@ def __init__( def __repr__(self) -> str: return ( - f"Record(id={self.id},fields={self.fields},metadata={self.metadata}," + f"Record(id={self.id},status={self.status},fields={self.fields},metadata={self.metadata}," f"suggestions={self.suggestions},responses={self.responses})" ) @@ -147,6 +147,10 @@ def metadata(self) -> "RecordMetadata": def vectors(self) -> "RecordVectors": return self.__vectors + @property + def status(self) -> str: + return self._model.status + @property def _server_id(self) -> Optional[UUID]: return self._model.id @@ -164,6 +168,7 @@ def api_model(self) -> RecordModel: vectors=self.vectors.api_models(), responses=self.responses.api_models(), suggestions=self.suggestions.api_models(), + status=self.status, ) def serialize(self) -> Dict[str, Any]: @@ -185,6 +190,7 @@ def to_dict(self) -> Dict[str, Dict]: """ id = str(self.id) if self.id else None server_id = str(self._model.id) if self._model.id else None + status = self.status fields = self.fields.to_dict() metadata = self.metadata.to_dict() suggestions = self.suggestions.to_dict() @@ -198,6 +204,7 @@ def to_dict(self) -> Dict[str, Dict]: "suggestions": suggestions, "responses": responses, "vectors": vectors, + "status": status, "_server_id": server_id, } @@ -245,7 +252,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": Returns: A Record object. """ - return cls( + instance = cls( id=model.external_id, fields=model.fields, metadata={meta.name: meta.value for meta in model.metadata}, @@ -257,10 +264,15 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": for response in UserResponse.from_model(response_model, dataset=dataset) ], suggestions=[Suggestion.from_model(model=suggestion, dataset=dataset) for suggestion in model.suggestions], - _dataset=dataset, - _server_id=model.id, ) + # set private attributes + instance._dataset = dataset + instance._model.id = model.id + instance._model.status = model.status + + return instance + class RecordFields(dict): """This is a container class for the fields of a Record. 
@@ -335,7 +347,7 @@ def to_dict(self) -> Dict[str, List[Dict]]: response_dict = defaultdict(list) for response in self.__responses: response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)}) - return response_dict + return dict(response_dict) def api_models(self) -> List[UserResponseModel]: """Returns a list of ResponseModel objects.""" diff --git a/argilla/tests/integration/test_list_records.py b/argilla/tests/integration/test_list_records.py index ec51124be5..58407bc273 100644 --- a/argilla/tests/integration/test_list_records.py +++ b/argilla/tests/integration/test_list_records.py @@ -65,6 +65,19 @@ def test_list_records_with_start_offset(client: Argilla, dataset: Dataset): records = list(dataset.records(start_offset=1)) assert len(records) == 1 + assert [record.to_dict() for record in records] == [ + { + "_server_id": str(records[0]._server_id), + "fields": {"text": "The record text field"}, + "id": "2", + "status": "pending", + "metadata": {}, + "responses": {}, + "suggestions": {}, + "vectors": {}, + } + ] + def test_list_records_with_responses(client: Argilla, dataset: Dataset): dataset.records.log( diff --git a/argilla/tests/unit/test_io/test_generic.py b/argilla/tests/unit/test_io/test_generic.py index 446693f5b5..374ee20eed 100644 --- a/argilla/tests/unit/test_io/test_generic.py +++ b/argilla/tests/unit/test_io/test_generic.py @@ -41,6 +41,7 @@ def test_to_list_flatten(self): assert records_list == [ { "id": str(record.id), + "status": "pending", "_server_id": None, "field": "The field", "key": "value", diff --git a/argilla/tests/unit/test_io/test_hf_datasets.py b/argilla/tests/unit/test_io/test_hf_datasets.py index f13ab04ef4..99e43d8caf 100644 --- a/argilla/tests/unit/test_io/test_hf_datasets.py +++ b/argilla/tests/unit/test_io/test_hf_datasets.py @@ -46,6 +46,7 @@ def test_to_datasets_with_partial_values_in_records(self): ds = HFDatasetsIO.to_datasets(records) assert ds.features == { + "status": Value(dtype="string", id=None), "_server_id": Value(dtype="null", id=None), "a": Value(dtype="string", id=None), "b": Value(dtype="string", id=None), diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index 09759430c7..b04adb0203 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -14,6 +14,8 @@ import uuid +import pytest + from argilla import Record, Suggestion, Response from argilla._models import MetadataModel @@ -31,6 +33,7 @@ def test_record_repr(self): ) assert ( record.__repr__() == f"Record(id={record_id}," + "status=pending," "fields={'name': 'John', 'age': '30'}," "metadata={'key': 'value'}," "suggestions={'question': {'value': 'answer', 'score': None, 'agent': None}}," @@ -62,3 +65,10 @@ def test_update_record_vectors(self): record.vectors["new-vector"] = [1.0, 2.0, 3.0] assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]} + + def test_prevent_update_record(self): + record = Record(fields={"name": "John"}) + assert record.status == "pending" + + with pytest.raises(AttributeError): + record.status = "completed" From a9375c15b164e9a5cff864dfcbfe351fb9d2c592 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 15 Jul 2024 11:04:16 +0200 Subject: [PATCH 23/34] [REFACTOR] cleaning list records endpoints (#5221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Pull Request Template This PR merges all the approved PRs related to 
cleaning list and search records endpoints - https://github.com/argilla-io/argilla/pull/5153 - https://github.com/argilla-io/argilla/pull/5156 - https://github.com/argilla-io/argilla/pull/5163 - https://github.com/argilla-io/argilla/pull/5166 **Type of change** - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: José Francisco Calvo --- argilla-server/CHANGELOG.md | 7 + .../api/handlers/v1/datasets/records.py | 168 +-- .../argilla_server/api/schemas/v1/records.py | 12 - .../src/argilla_server/search_engine/base.py | 85 +- .../argilla_server/search_engine/commons.py | 90 +- .../datasets/test_search_dataset_records.py | 6 - .../unit/api/handlers/v1/test_datasets.py | 317 +++--- .../handlers/v1/test_list_dataset_records.py | 991 +----------------- .../tests/unit/search_engine/test_commons.py | 104 +- 9 files changed, 218 insertions(+), 1562 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 811497a370..3b7a833ff9 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -39,6 +39,13 @@ These are the section headers that we use: - [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) - [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) +### Removed + +- Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) +- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) +- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) +- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. 
([#5166](https://github.com/argilla-io/argilla/pull/5166)) + ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) ### Changed diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 8cc5ee2538..0fca256da4 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -19,7 +19,6 @@ from fastapi import APIRouter, Depends, Query, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from typing_extensions import Annotated import argilla_server.search_engine as search_engine from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized @@ -27,8 +26,6 @@ Filters, FilterScope, MetadataFilterScope, - MetadataParsedQueryParam, - MetadataQueryParams, Order, RangeFilter, RecordFilterScope, @@ -49,19 +46,14 @@ ) from argilla_server.contexts import datasets, search from argilla_server.database import get_async_db -from argilla_server.enums import MetadataPropertyType, RecordSortField, ResponseStatusFilter, SortOrder +from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE -from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings +from argilla_server.models import Dataset, Field, Record, User, VectorSettings from argilla_server.search_engine import ( AndFilter, - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, SearchEngine, SearchResponses, - SortBy, - TermsMetadataFilter, UserResponseStatusFilter, get_search_engine, ) @@ -74,25 +66,6 @@ LIST_DATASET_RECORDS_DEFAULT_SORT_BY = {RecordSortField.inserted_at.value: "asc"} DELETE_DATASET_RECORDS_LIMIT = 100 -_RECORD_SORT_FIELD_VALUES = tuple(field.value for field in RecordSortField) -_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder) -_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P(?=.*[a-z0-9])[a-z0-9_-]+)$") - -SortByQueryParamParsed = Annotated[ - Dict[str, str], - Depends( - parse_query_param( - name="sort_by", - description=( - "The field used to sort the records. 
Expected format is `field` or `field:{asc,desc}`, where `field`" - " can be 'inserted_at', 'updated_at' or the name of a metadata property" - ), - max_values_per_key=1, - group_keys_without_values=False, - ) - ), -] - parse_record_include_param = parse_query_param( name="include", help="Relationships to include in the response", model=RecordIncludeParam ) @@ -104,13 +77,10 @@ async def _filter_records_using_search_engine( db: "AsyncSession", search_engine: "SearchEngine", dataset: Dataset, - parsed_metadata: List[MetadataParsedQueryParam], limit: int, offset: int, user: Optional[User] = None, - response_statuses: Optional[List[ResponseStatusFilter]] = None, include: Optional[RecordIncludeParam] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> Tuple[List[Record], int]: search_responses = await _get_search_responses( db=db, @@ -119,9 +89,6 @@ async def _filter_records_using_search_engine( limit=limit, offset=offset, user=user, - parsed_metadata=parsed_metadata, - response_statuses=response_statuses, - sort_by_query_param=sort_by_query_param, ) record_ids = [response.record_id for response in search_responses.items] @@ -180,13 +147,10 @@ async def _get_search_responses( db: "AsyncSession", search_engine: "SearchEngine", dataset: Dataset, - parsed_metadata: List[MetadataParsedQueryParam], limit: int, offset: int, search_records_query: Optional[SearchRecordsQuery] = None, user: Optional[User] = None, - response_statuses: Optional[List[ResponseStatusFilter]] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> "SearchResponses": search_records_query = search_records_query or SearchRecordsQuery() @@ -226,10 +190,6 @@ async def _get_search_responses( if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - metadata_filters = await _build_metadata_filters(db, dataset, parsed_metadata) - response_status_filter = await _build_response_status_filter_for_search(response_statuses, user=user) - sort_by = await _build_sort_by(db, dataset, sort_by_query_param) - if vector_query and vector_settings: similarity_search_params = { "dataset": dataset, @@ -238,8 +198,6 @@ async def _get_search_responses( "record": record, "query": text_query, "order": vector_query.order, - "metadata_filters": metadata_filters, - "user_response_status_filter": response_status_filter, "max_results": limit, } @@ -251,11 +209,8 @@ async def _get_search_responses( search_params = { "dataset": dataset, "query": text_query, - "metadata_filters": metadata_filters, - "user_response_status_filter": response_status_filter, "offset": offset, "limit": limit, - "sort_by": sort_by, } if user is not None: @@ -269,32 +224,6 @@ async def _get_search_responses( return await search_engine.search(**search_params) -async def _build_metadata_filters( - db: "AsyncSession", dataset: Dataset, parsed_metadata: List[MetadataParsedQueryParam] -) -> List["MetadataFilter"]: - try: - metadata_filters = [] - for metadata_param in parsed_metadata: - metadata_property = await MetadataProperty.get_by(db, name=metadata_param.name, dataset_id=dataset.id) - if metadata_property is None: - continue # won't fail on unknown metadata filter name - - if metadata_property.type == MetadataPropertyType.terms: - metadata_filter_class = TermsMetadataFilter - elif metadata_property.type == MetadataPropertyType.integer: - metadata_filter_class = IntegerMetadataFilter - elif 
metadata_property.type == MetadataPropertyType.float: - metadata_filter_class = FloatMetadataFilter - else: - raise ValueError(f"Not found filter for type {metadata_property.type}") - - metadata_filters.append(metadata_filter_class.from_string(metadata_property, metadata_param.value)) - except (UnprocessableEntityError, ValueError) as ex: - raise UnprocessableEntityError(f"Cannot parse provided metadata filters: {ex}") - - return metadata_filters - - async def _build_response_status_filter_for_search( response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None ) -> Optional[UserResponseStatusFilter]: @@ -307,43 +236,6 @@ async def _build_response_status_filter_for_search( return user_response_status_filter -async def _build_sort_by( - db: "AsyncSession", dataset: Dataset, sort_by_query_param: Optional[Dict[str, str]] = None -) -> Union[List[SortBy], None]: - if sort_by_query_param is None: - return None - - sorts_by = [] - for sort_field, sort_order in sort_by_query_param.items(): - if sort_field in _RECORD_SORT_FIELD_VALUES: - field = sort_field - elif (match := _METADATA_PROPERTY_SORT_BY_REGEX.match(sort_field)) is not None: - metadata_property_name = match.group("name") - metadata_property = await MetadataProperty.get_by(db, name=metadata_property_name, dataset_id=dataset.id) - if not metadata_property: - raise UnprocessableEntityError( - f"Provided metadata property in 'sort_by' query param '{metadata_property_name}' not found in " - f"dataset with '{dataset.id}'." - ) - - field = metadata_property - else: - valid_sort_fields = ", ".join(f"'{sort_field}'" for sort_field in _RECORD_SORT_FIELD_VALUES) - raise UnprocessableEntityError( - f"Provided sort field in 'sort_by' query param '{sort_field}' is not valid. It must be either" - f" {valid_sort_fields} or `metadata.metadata-property-name`" - ) - - if sort_order is not None and sort_order not in _VALID_SORT_VALUES: - raise UnprocessableEntityError( - f"Provided sort order in 'sort_by' query param '{sort_order}' for field '{sort_field}' is not valid.", - ) - - sorts_by.append(SortBy(field=field, order=sort_order or SortOrder.asc.value)) - - return sorts_by - - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): try: await search.validate_search_records_query(db, query, dataset_id) @@ -351,54 +243,13 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord raise UnprocessableEntityError(str(e)) -@router.get("/me/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) -async def list_current_user_dataset_records( - *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - dataset_id: UUID, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, - include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), - offset: int = 0, - limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), - current_user: User = Security(auth.get_current_user), -): - dataset = await Dataset.get_or_raise(db, dataset_id, options=[selectinload(Dataset.metadata_properties)]) - - await authorize(current_user, DatasetPolicy.get(dataset)) - - records, total = await _filter_records_using_search_engine( - db, - search_engine, - dataset=dataset, - parsed_metadata=metadata.metadata_parsed, - 
limit=limit, - offset=offset, - user=current_user, - response_statuses=response_statuses, - include=include, - sort_by_query_param=sort_by_query_param, - ) - - for record in records: - record.dataset = dataset - record.metadata_ = await _filter_record_metadata_for_user(record, current_user) - - return Records(items=records, total=total) - - @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -411,12 +262,9 @@ async def list_dataset_records( db, search_engine, dataset=dataset, - parsed_metadata=metadata.metadata_parsed, limit=limit, offset=offset, - response_statuses=response_statuses, include=include, - sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY, ) return Records(items=records, total=total) @@ -460,10 +308,7 @@ async def search_current_user_dataset_records( telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, body: SearchRecordsQuery, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -486,12 +331,9 @@ async def search_current_user_dataset_records( search_engine=search_engine, dataset=dataset, search_records_query=body, - parsed_metadata=metadata.metadata_parsed, limit=limit, offset=offset, user=current_user, - response_statuses=response_statuses, - sort_by_query_param=sort_by_query_param, ) record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = { @@ -534,10 +376,7 @@ async def search_dataset_records( search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -555,9 +394,6 @@ async def search_dataset_records( search_records_query=body, limit=limit, offset=offset, - parsed_metadata=metadata.metadata_parsed, - response_statuses=response_statuses, - sort_by_query_param=sort_by_query_param, ) record_id_score_map = { diff --git a/argilla-server/src/argilla_server/api/schemas/v1/records.py b/argilla-server/src/argilla_server/api/schemas/v1/records.py index 0cf215954a..b5ff7c3f4c 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/records.py +++ 
b/argilla-server/src/argilla_server/api/schemas/v1/records.py @@ -13,12 +13,9 @@ # limitations under the License. from datetime import datetime - from typing import Annotated, Any, Dict, List, Literal, Optional, Union from uuid import UUID -import fastapi - from argilla_server.api.schemas.v1.commons import UpdateSchema from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate @@ -223,15 +220,6 @@ def __init__(self, string: str): self.value: str = "".join(v).strip() -class MetadataQueryParams(BaseModel): - metadata: List[str] = Field(fastapi.Query([], pattern=r"^(?=.*[a-z0-9])[a-z0-9_-]+:(.+(,(.+))*)$")) - - @property - def metadata_parsed(self) -> List[MetadataParsedQueryParam]: - # TODO: Validate metadata fields names from query params - return [MetadataParsedQueryParam(q) for q in self.metadata] - - class VectorQuery(BaseModel): name: str record_id: Optional[UUID] = None diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index ee1dbcc386..db5bc87e2a 100644 --- a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -15,17 +15,13 @@ from abc import ABCMeta, abstractmethod from contextlib import asynccontextmanager from typing import ( - Any, AsyncGenerator, - ClassVar, - Dict, Generic, Iterable, List, Optional, - Type, - TypeVar, Union, + TypeVar, ) from uuid import UUID @@ -38,16 +34,12 @@ SortOrder, ) from argilla_server.models import Dataset, MetadataProperty, Record, Response, Suggestion, User, Vector, VectorSettings -from argilla_server.pydantic_v1 import BaseModel, Field, root_validator +from argilla_server.pydantic_v1 import BaseModel, Field from argilla_server.pydantic_v1.generics import GenericModel __all__ = [ "SearchEngine", "TextQuery", - "MetadataFilter", - "TermsMetadataFilter", - "IntegerMetadataFilter", - "FloatMetadataFilter", "UserResponseStatusFilter", "SearchResponseItem", "SearchResponses", @@ -147,67 +139,6 @@ def has_pending_status(self) -> bool: return ResponseStatusFilter.pending in self.statuses or ResponseStatusFilter.missing in self.statuses -class MetadataFilter(BaseModel): - metadata_property: MetadataProperty - - class Config: - arbitrary_types_allowed = True - - @classmethod - @abstractmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "MetadataFilter": - pass - - -class TermsMetadataFilter(MetadataFilter): - values: List[str] - - @classmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "MetadataFilter": - return cls(metadata_property=metadata_property, values=string.split(",")) - - -NT = TypeVar("NT", int, float) - - -class _RangeModel(GenericModel, Generic[NT]): - ge: Optional[NT] - le: Optional[NT] - - -class NumericMetadataFilter(GenericModel, Generic[NT], MetadataFilter): - ge: Optional[NT] = None - le: Optional[NT] = None - - _json_model: ClassVar[Type[_RangeModel]] - - @root_validator(skip_on_failure=True) - def check_bounds(cls, values: Dict[str, Any]) -> Dict[str, Any]: - ge = values.get("ge") - le = values.get("le") - - if ge is None and le is None: - raise ValueError("One of 'ge' or 'le' values must be specified") - - if ge is not None and le is not None and ge > le: - raise ValueError(f"'ge' ({ge}) must be lower or equal than 'le' ({le})") - - return values - - @classmethod - def from_string(cls, metadata_property: 
MetadataProperty, string: str) -> "NumericMetadataFilter": - model = cls._json_model.parse_raw(string) - return cls(metadata_property=metadata_property, ge=model.ge, le=model.le) - - -class IntegerMetadataFilter(NumericMetadataFilter[int]): - _json_model = _RangeModel[int] - - -class FloatMetadataFilter(NumericMetadataFilter[float]): - _json_model = _RangeModel[float] - - class SearchResponseItem(BaseModel): record_id: UUID score: Optional[float] @@ -236,6 +167,9 @@ class TermCount(BaseModel): values: List[TermCount] = Field(default_factory=list) +NT = TypeVar("NT", int, float) + + class NumericMetadataMetrics(GenericModel, Generic[NT]): min: Optional[NT] max: Optional[NT] @@ -348,11 +282,6 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: remove them and keep filter and order - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, ) -> SearchResponses: @@ -378,10 +307,6 @@ async def similarity_search( record: Optional[Record] = None, query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, - # TODO: remove them and keep filter - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - # END TODO max_results: int = 100, order: SimilarityOrder = SimilarityOrder.most_similar, threshold: Optional[float] = None, diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index a081105d16..e6541309a0 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -38,11 +38,8 @@ AndFilter, Filter, FilterScope, - FloatMetadataFilter, FloatMetadataMetrics, - IntegerMetadataFilter, IntegerMetadataMetrics, - MetadataFilter, MetadataFilterScope, MetadataMetrics, Order, @@ -55,7 +52,6 @@ SortBy, SuggestionFilterScope, TermsFilter, - TermsMetadataFilter, TermsMetadataMetrics, TextQuery, UserResponseStatusFilter, @@ -97,9 +93,6 @@ def es_bool_query( if must_not: bool_query["must_not"] = must_not - if not bool_query: - raise ValueError("Cannot build a boolean query without any clause") - if minimum_should_match: bool_query["minimum_should_match"] = minimum_should_match @@ -210,55 +203,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) -# This function will be moved once the `metadata_filters` argument is removed from search and similarity_search methods -def _unify_metadata_filters_with_filter(metadata_filters: List[MetadataFilter], filter: Optional[Filter]) -> Filter: - filters = [] - if filter: - filters.append(filter) - - for metadata_filter in metadata_filters: - metadata_scope = MetadataFilterScope(metadata_property=metadata_filter.metadata_property.name) - if isinstance(metadata_filter, TermsMetadataFilter): - new_filter = TermsFilter(scope=metadata_scope, values=metadata_filter.values) - elif isinstance(metadata_filter, (IntegerMetadataFilter, FloatMetadataFilter)): - new_filter = RangeFilter(scope=metadata_scope, ge=metadata_filter.ge, le=metadata_filter.le) - else: - raise ValueError(f"Cannot process request for metadata filter {metadata_filter}") - filters.append(new_filter) - - return AndFilter(filters=filters) - - -# This function 
will be moved once the response status filter is removed from search and similarity_search methods -def _unify_user_response_status_filter_with_filter( - user_response_status_filter: UserResponseStatusFilter, filter: Optional[Filter] = None -) -> Filter: - scope = ResponseFilterScope(user=user_response_status_filter.user, property="status") - response_filter = TermsFilter(scope=scope, values=[status.value for status in user_response_status_filter.statuses]) - - if filter: - return AndFilter(filters=[filter, response_filter]) - else: - return response_filter - - -# This function will be moved once the `sort_by` argument is removed from search and similarity_search methods -def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]: - if order: - return order - - new_order = [] - for sort in sort_by: - if isinstance(sort.field, MetadataProperty): - scope = MetadataFilterScope(metadata_property=sort.field.name) - else: - scope = RecordFilterScope(property=sort.field) - - new_order.append(Order(scope=scope, order=sort.order)) - - return new_order - - def is_response_status_scope(scope: FilterScope) -> bool: return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None @@ -370,14 +314,14 @@ async def update_record_response(self, response: Response): es_responses = self._map_record_responses_to_es([response]) - await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}}) + await self._update_document_request(index_name, id=str(record.id), body={"doc": {"responses": es_responses}}) async def delete_record_response(self, response: Response): record = response.record index_name = await self._get_dataset_index(record.dataset) await self._update_document_request( - index_name, id=record.id, body={"script": es_script_for_delete_user_response(response.user)} + index_name, id=str(record.id), body={"script": es_script_for_delete_user_response(response.user)} ) async def update_record_suggestion(self, suggestion: Suggestion): @@ -387,7 +331,7 @@ async def update_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"doc": {"suggestions": es_suggestions}}, ) @@ -396,7 +340,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'}, ) @@ -423,21 +367,10 @@ async def similarity_search( record: Optional[Record] = None, query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, - # TODO: remove them and keep filter - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - # END TODO max_results: int = 100, order: SimilarityOrder = SimilarityOrder.most_similar, threshold: Optional[float] = None, ) -> SearchResponses: - # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if metadata_filters: - filter = _unify_metadata_filters_with_filter(metadata_filters, filter) - if user_response_status_filter and user_response_status_filter.statuses: - filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) - # END TODO - if bool(value) == bool(record): raise ValueError("Must provide either vector value or 
record to compute the similarity search") @@ -629,26 +562,11 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: Remove these arguments - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, user_id: Optional[str] = None, ) -> SearchResponses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html - - # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if metadata_filters: - filter = _unify_metadata_filters_with_filter(metadata_filters, filter) - if user_response_status_filter and user_response_status_filter.statuses: - filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) - - if sort_by: - sort = _unify_sort_by_with_order(sort_by, sort) - # END TODO index = await self._get_dataset_index(dataset) text_query = self._build_text_query(dataset, text=query) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 73077c4381..5e3c6653de 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -316,12 +316,9 @@ async def test_with_filter( RangeFilter(scope=SuggestionFilterScope(question=question.name, property="score"), ge=0.5), ] ), - metadata_filters=[], offset=0, limit=50, query=None, - sort_by=None, - user_response_status_filter=None, ) async def test_with_sort( @@ -367,12 +364,9 @@ async def test_with_sort( Order(scope=ResponseFilterScope(question=question.name), order=SortOrder.asc), Order(scope=SuggestionFilterScope(question=question.name, property="score"), order=SortOrder.desc), ], - metadata_filters=[], offset=0, limit=50, query=None, - sort_by=None, - user_response_status_filter=None, ) async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict): diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 557cb4de70..a259baa773 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -14,11 +14,13 @@ import math import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type from unittest.mock import ANY, MagicMock from uuid import UUID, uuid4 import pytest +from sqlalchemy import func, inspect, select + from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.api.schemas.v1.datasets import DATASET_GUIDELINES_MAX_LENGTH, DATASET_NAME_MAX_LENGTH from argilla_server.api.schemas.v1.fields import FIELD_CREATE_NAME_MAX_LENGTH, FIELD_CREATE_TITLE_MAX_LENGTH @@ -41,6 +43,7 @@ ResponseStatusFilter, SimilarityOrder, RecordStatus, + SortOrder, ) from argilla_server.models import ( Dataset, @@ -57,19 +60,18 @@ VectorSettings, ) from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, SearchEngine, SearchResponseItem, SearchResponses, - SortBy, 
- TermsMetadataFilter, TextQuery, - UserResponseStatusFilter, + AndFilter, + TermsFilter, + MetadataFilterScope, + RangeFilter, + ResponseFilterScope, + Order, + RecordFilterScope, ) -from sqlalchemy import func, inspect, select - from tests.factories import ( AdminFactory, AnnotatorFactory, @@ -80,7 +82,6 @@ LabelSelectionQuestionFactory, MetadataPropertyFactory, MultiLabelSelectionQuestionFactory, - OwnerFactory, QuestionFactory, RatingQuestionFactory, RecordFactory, @@ -3592,11 +3593,8 @@ async def test_search_current_user_dataset_records( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -3633,55 +3631,85 @@ async def test_search_current_user_dataset_records( } @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), + ("property_config", "metadata_filter", "expected_filter"), [ ( {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), + { + "type": "terms", + "values": ["value"], + "scope": {"entity": "metadata", "metadata_property": "terms_prop"}, + }, + TermsFilter(scope=MetadataFilterScope(metadata_property="terms_prop"), values=["value"]), ), ( {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), + { + "type": "terms", + "values": ["value1", "value2"], + "scope": {"entity": "metadata", "metadata_property": "terms_prop"}, + }, + TermsFilter(scope=MetadataFilterScope(metadata_property="terms_prop"), values=["value1", "value2"]), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), + { + "type": "range", + "ge": 10, + "le": 20, + "scope": {"entity": "metadata", "metadata_property": "integer_prop"}, + }, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + ge=10, + le=20, + ), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, high=None), + {"type": "range", "ge": 20, "scope": {"entity": "metadata", "metadata_property": "integer_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + ge=20, + ), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(low=None, le=20), + {"type": "range", "le": 20, "scope": {"entity": "metadata", "metadata_property": "integer_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + le=20, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), + { + "type": "range", + "ge": -1.30, + "le": 23.23, + "scope": {"entity": "metadata", "metadata_property": "float_prop"}, + }, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + ge=-1.30, + le=23.23, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, high=None), + {"type": "range", "ge": 23.23, "scope": {"entity": "metadata", "metadata_property": "float_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + ge=23.23, + ), ), ( {"name": "float_prop", "settings": 
{"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(low=None, le=11.32), + {"type": "range", "le": 11.32, "scope": {"entity": "metadata", "metadata_property": "float_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + le=11.32, + ), ), ], ) @@ -3691,15 +3719,14 @@ async def test_search_current_user_dataset_records_with_metadata_filter( mock_search_engine: SearchEngine, owner: User, owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, + property_config, + metadata_filter: dict, + expected_filter: Any, ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - metadata_property = await MetadataPropertyFactory.create( + await MetadataPropertyFactory.create( name=property_config["name"], settings=property_config["settings"], dataset=dataset, @@ -3713,12 +3740,9 @@ async def test_search_current_user_dataset_records_with_metadata_filter( ], ) - params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}, "filters": {"and": [metadata_filter]}} response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=params, headers=owner_auth_header, json=query_json, ) @@ -3727,91 +3751,45 @@ async def test_search_current_user_dataset_records_with_metadata_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, + filter=AndFilter(filters=[expected_filter]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) @pytest.mark.parametrize( - ("property_config", "wrong_value"), + "sort,expected_sort", [ - ({"name": "terms_prop", "settings": {"type": "terms"}}, None), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "terms_prop"), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "terms_prop:"), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "wrong-value"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, None), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop:"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop:{}"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "wrong-value"), - ({"name": "float_prop", "settings": {"type": "float"}}, None), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop"), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop:"), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop:{}"), - ({"name": "float_prop", "settings": {"type": "float"}}, "wrong-value"), - ], - ) - async def test_search_current_user_dataset_records_with_wrong_metadata_filter_values( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - wrong_value: str, - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - - await MetadataPropertyFactory.create( - 
name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - total=2, - ) - - params = {"metadata": [wrong_value]} - - query_json = {"query": {"text": {"q": "Hello"}}} - response = await async_client.post( - f"/api/v1/me/datasets/{dataset.id}/records/search", - params=params, - headers=owner_auth_header, - json=query_json, - ) - assert response.status_code == 422, response.json() - - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "asc"}], + [Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.asc)], + ), + ( + [ + {"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}, + {"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "desc"}, + ], + [ + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), + Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.desc), + ], + ), ], ) async def test_search_current_user_dataset_records_with_sort_by( @@ -3820,16 +3798,15 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine: SearchEngine, owner: "User", owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], + sort: List[dict], + expected_sort: List[Order], ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await 
TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) + for order in expected_sort: + if isinstance(order.scope, MetadataFilterScope): + await TermsMetadataPropertyFactory.create(name=order.scope.metadata_property, dataset=dataset) mock_search_engine.search.return_value = SearchResponses( total=2, @@ -3839,15 +3816,13 @@ async def test_search_current_user_dataset_records_with_sort_by( ], ) - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": sort, } - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} - response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=query_params, headers=owner_auth_header, json=query_json, ) @@ -3857,11 +3832,9 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, + sort=expected_sort, user_id=owner.id, ) @@ -3871,18 +3844,17 @@ async def test_search_current_user_dataset_records_with_sort_by_with_wrong_sort_ workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "record", "property": "wrong_property"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } async def test_search_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property( self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict @@ -3890,17 +3862,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_non_existen workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "metadata", "metadata_property": "missing"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "metadata.i-do-not-exist:asc"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." 
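As a reference for the breaking change listed in the CHANGELOG entries above, the sketch below (assembled from the tests in this patch, not an authoritative spec) shows the shape of the request body that replaces the removed `metadata`, `response_status` and `sort_by` query params on the search endpoints:

```python
# Illustrative request body for POST /api/v1/datasets/:dataset_id/records/search.
# Filtering and sorting now travel in the body instead of query params; the
# property names ("integer_prop", "input") are taken from the tests, not fixed.
query_json = {
    "query": {"text": {"q": "Hello", "field": "input"}},
    "filters": {
        "and": [
            # response status filter (formerly the `response_status` query param)
            {
                "type": "terms",
                "scope": {"entity": "response", "property": "status"},
                "values": ["submitted"],
            },
            # metadata filter (formerly the `metadata` query param)
            {
                "type": "range",
                "scope": {"entity": "metadata", "metadata_property": "integer_prop"},
                "ge": 10,
                "le": 20,
            },
        ]
    },
    # sorting (formerly the `sort_by` query param)
    "sort": [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "asc"}],
}
```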
+ "detail": f"MetadataProperty not found filtering by name=missing, dataset_id={dataset.id}" } async def test_search_current_user_dataset_records_with_sort_by_with_invalid_field( @@ -3909,19 +3883,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_invalid_fie workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [ + {"scope": {"entity": "wrong", "property": "wrong"}, "order": "asc"}, + ], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "not-valid"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. " - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } @pytest.mark.parametrize( "includes", @@ -4063,9 +4037,6 @@ async def test_search_current_user_dataset_records_with_include( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - sort_by=None, - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4268,22 +4239,37 @@ async def test_search_current_user_dataset_records_with_response_status_filter( dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) mock_search_engine.search.return_value = SearchResponses(items=[]) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "filters": { + "and": [ + { + "type": "terms", + "scope": {"entity": "response", "property": "status"}, + "values": [ResponseStatus.submitted], + } + ] + }, + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=owner_auth_header, json=query_json, - params={"response_status": ResponseStatus.submitted.value}, ) mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]), + filter=AndFilter( + filters=[ + TermsFilter( + scope=ResponseFilterScope(property="status", user=owner), + values=[ResponseStatusFilter.submitted], + ) + ] + ), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -4326,8 +4312,6 @@ async def test_search_current_user_dataset_records_with_record_vector( query=None, order=SimilarityOrder.most_similar, max_results=5, - metadata_filters=[], - user_response_status_filter=None, ) async def test_search_current_user_dataset_records_with_vector_value( @@ -4370,8 +4354,6 @@ async def test_search_current_user_dataset_records_with_vector_value( query=None, order=SimilarityOrder.most_similar, max_results=10, - metadata_filters=[], - user_response_status_filter=None, ) async def test_search_current_user_dataset_records_with_vector_value_and_query( @@ -4419,8 +4401,6 @@ async def test_search_current_user_dataset_records_with_vector_value_and_query( query=TextQuery(q="Test query"), order=SimilarityOrder.most_similar, max_results=10, - metadata_filters=[], - user_response_status_filter=None, ) async def 
test_search_current_user_dataset_records_with_wrong_vector( @@ -4512,11 +4492,8 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], - user_response_status_filter=None, offset=0, limit=5, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 8f78940df3..4f989e5399 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -12,43 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Type, Union -from uuid import uuid4 +from typing import List, Optional, Tuple, Union import pytest -from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT -from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole, RecordStatus -from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace -from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, - SearchEngine, - SearchResponseItem, - SearchResponses, - SortBy, - TermsMetadataFilter, -) from httpx import AsyncClient +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import RecordInclude, ResponseStatus +from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace from tests.factories import ( AdminFactory, AnnotatorFactory, DatasetFactory, LabelSelectionQuestionFactory, - MetadataPropertyFactory, RecordFactory, ResponseFactory, SuggestionFactory, - TermsMetadataPropertyFactory, TextFieldFactory, TextQuestionFactory, - UserFactory, VectorFactory, VectorSettingsFactory, WorkspaceFactory, - WorkspaceUserFactory, ) @@ -398,108 +382,6 @@ async def create_records_with_response( for record in await RecordFactory.create_batch(size=num_records, dataset=dataset): await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status) - @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), - [ - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), - ), - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, high=None), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(ge=None, le=20), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, high=None), - ), - ( - {"name": 
"float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(ge=None, le=11.32), - ), - ], - ) - async def test_list_dataset_records_with_metadata_filter( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - metadata_property = await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - - response_json = response.json() - assert response_json["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=[SortBy(field=RecordSortField.inserted_at)], - ) - @pytest.mark.skip(reason="Factory integration with search engine") @pytest.mark.parametrize( "response_status_filter", ["missing", "pending", "discarded", "submitted", "draft", ["submitted", "draft"]] @@ -563,121 +445,6 @@ async def test_list_dataset_records_with_response_status_filter( ] ) - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - ], - ) - async def test_list_dataset_records_with_sort_by( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: "User", - owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - 
expected_sorts_by.append(SortBy(field=field, order=order or "asc")) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] - } - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - assert response.json()["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, - ) - - async def test_list_dataset_records_with_sort_by_with_wrong_sort_order_value( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } - - async def test_list_dataset_records_with_sort_by_with_non_existent_metadata_property( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "metadata.i-do-not-exist:asc"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." - } - - async def test_list_dataset_records_with_sort_by_with_invalid_field( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "not-valid"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } - async def test_list_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() @@ -793,753 +560,3 @@ async def create_dataset_with_user_responses( ] return dataset, questions, records, responses, suggestions - - async def test_list_current_user_dataset_records( - self, async_client: "AsyncClient", mock_search_engine: SearchEngine, owner: User, owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - record_a, record_b, record_c = records - - mock_search_engine.search.return_value = SearchResponses( - total=3, - items=[ - SearchResponseItem(record_id=record_a.id, score=14.2), - SearchResponseItem(record_id=record_b.id, score=12.2), - SearchResponseItem(record_id=record_c.id, score=10.2), - ], - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header) - - assert response.status_code == 200 - assert response.json() == { - "total": 3, - "items": [ - { - "id": str(record_a.id), - "status": RecordStatus.pending, - "fields": {"input": "input_a", "output": "output_a"}, - "metadata": None, - "dataset_id": str(dataset.id), - "external_id": record_a.external_id, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "status": RecordStatus.pending, - "fields": {"input": "input_b", "output": "output_b"}, - "metadata": {"unit": "test"}, - "dataset_id": str(dataset.id), - "external_id": record_b.external_id, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "status": RecordStatus.pending, - "fields": {"input": "input_c", "output": "output_c"}, - "metadata": None, - "dataset_id": str(dataset.id), - "external_id": record_c.external_id, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - } - - async def test_list_current_user_dataset_records_with_filtered_metadata_as_annotator( - self, async_client: "AsyncClient", mock_search_engine: SearchEngine, owner: User - ): - workspace = await WorkspaceFactory.create() - user = await AnnotatorFactory.create() - await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) - - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - await TermsMetadataPropertyFactory.create( - name="key1", - dataset=dataset, - allowed_roles=[UserRole.admin, UserRole.annotator], - ) - await TermsMetadataPropertyFactory.create( - name="key2", - dataset=dataset, - allowed_roles=[UserRole.admin], - ) - await TermsMetadataPropertyFactory.create( - name="key3", - dataset=dataset, - allowed_roles=[UserRole.admin], - ) - - record = await RecordFactory.create( - dataset=dataset, - fields={"input": "input_b", "output": "output_b"}, - metadata_={"key1": "value1", "key2": "value2", "key3": "value3", "extra": "extra"}, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=1, - items=[SearchResponseItem(record_id=record.id, score=14.2)], - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await 
async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key} - ) - - assert response.status_code == 200 - assert response.json() == { - "total": 1, - "items": [ - { - "id": str(record.id), - "status": RecordStatus.pending, - "fields": {"input": "input_b", "output": "output_b"}, - "metadata": {"key1": "value1"}, - "dataset_id": str(dataset.id), - "external_id": record.external_id, - "inserted_at": record.inserted_at.isoformat(), - "updated_at": record.updated_at.isoformat(), - } - ], - } - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin, UserRole.owner]) - @pytest.mark.parametrize( - "includes", - [[RecordInclude.responses], [RecordInclude.suggestions], [RecordInclude.responses, RecordInclude.suggestions]], - ) - async def test_list_current_user_dataset_records_with_include( - self, async_client: "AsyncClient", role: UserRole, includes: List[RecordInclude] - ): - workspace = await WorkspaceFactory.create() - user = await UserFactory.create(workspaces=[workspace], role=role) - dataset, questions, records, responses, suggestions = await self.create_dataset_with_user_responses( - user, workspace - ) - record_a, record_b, record_c = records - response_a_user, response_b_user = responses[1], responses[3] - suggestion_a, suggestion_b = suggestions - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - params = [("include", include.value) for include in includes] - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", params=params, headers={API_KEY_HEADER_NAME: user.api_key} - ) - - expected = { - "total": 3, - "items": [ - { - "id": str(record_a.id), - "fields": {"input": "input_a", "output": "output_a"}, - "metadata": None, - "external_id": record_a.external_id, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "fields": {"input": "input_b", "output": "output_b"}, - "metadata": {"unit": "test"}, - "external_id": record_b.external_id, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "fields": {"input": "input_c", "output": "output_c"}, - "metadata": None, - "external_id": record_c.external_id, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - } - - if RecordInclude.responses in includes: - expected["items"][0]["responses"] = [ - { - "id": str(response_a_user.id), - "values": None, - "status": "discarded", - "user_id": str(user.id), - "inserted_at": response_a_user.inserted_at.isoformat(), - "updated_at": response_a_user.updated_at.isoformat(), - } - ] - expected["items"][1]["responses"] = [ - { - "id": str(response_b_user.id), - "values": { - "input_ok": {"value": "no"}, - "output_ok": {"value": "no"}, - }, - "status": "submitted", - "user_id": str(user.id), - "inserted_at": response_b_user.inserted_at.isoformat(), - "updated_at": response_b_user.updated_at.isoformat(), - }, - ] - expected["items"][2]["responses"] = [] - - if RecordInclude.suggestions in includes: - expected["items"][0]["suggestions"] = [ - { - "id": str(suggestion_a.id), - "value": "option-1", - "score": None, - "agent": None, - "type": None, - "question_id": str(questions[0].id), - } - ] - expected["items"][1]["suggestions"] = [ - { - "id": str(suggestion_b.id), - 
"value": "option-2", - "score": 0.75, - "agent": "unit-test-agent", - "type": "model", - "question_id": str(questions[0].id), - } - ] - expected["items"][2]["suggestions"] = [] - - assert response.status_code == 200 - assert response.json() == expected - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_include_vectors( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) - record_c = await RecordFactory.create(dataset=dataset) - vector_settings_a = await VectorSettingsFactory.create(name="vector-a", dimensions=3, dataset=dataset) - vector_settings_b = await VectorSettingsFactory.create(name="vector-b", dimensions=2, dataset=dataset) - - await VectorFactory.create(value=[1.0, 2.0, 3.0], vector_settings=vector_settings_a, record=record_a) - await VectorFactory.create(value=[4.0, 5.0], vector_settings=vector_settings_b, record=record_a) - await VectorFactory.create(value=[1.0, 2.0], vector_settings=vector_settings_b, record=record_b) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"include": RecordInclude.vectors.value}, - headers=owner_auth_header, - ) - - assert response.status_code == 200 - assert response.json() == { - "items": [ - { - "id": str(record_a.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_a.external_id, - "vectors": { - "vector-a": [1.0, 2.0, 3.0], - "vector-b": [4.0, 5.0], - }, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_b.external_id, - "vectors": { - "vector-b": [1.0, 2.0], - }, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_c.external_id, - "vectors": {}, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - "total": 3, - } - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_include_specific_vectors( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) - record_c = await RecordFactory.create(dataset=dataset) - vector_settings_a = await VectorSettingsFactory.create(name="vector-a", dimensions=3, dataset=dataset) - vector_settings_b = await VectorSettingsFactory.create(name="vector-b", dimensions=2, dataset=dataset) - vector_settings_c = await VectorSettingsFactory.create(name="vector-c", dimensions=4, dataset=dataset) - - await VectorFactory.create(value=[1.0, 2.0, 3.0], vector_settings=vector_settings_a, record=record_a) - await VectorFactory.create(value=[4.0, 5.0], vector_settings=vector_settings_b, record=record_a) - await VectorFactory.create(value=[6.0, 7.0, 8.0, 9.0], vector_settings=vector_settings_c, record=record_a) - await VectorFactory.create(value=[1.0, 2.0], vector_settings=vector_settings_b, record=record_b) - await 
VectorFactory.create(value=[10.0, 11.0, 12.0, 13.0], vector_settings=vector_settings_c, record=record_b) - await VectorFactory.create(value=[14.0, 15.0, 16.0, 17.0], vector_settings=vector_settings_c, record=record_c) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"include": f"{RecordInclude.vectors.value}:{vector_settings_a.name},{vector_settings_b.name}"}, - headers=owner_auth_header, - ) - - assert response.status_code == 200 - assert response.json() == { - "items": [ - { - "id": str(record_a.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_a.external_id, - "vectors": { - "vector-a": [1.0, 2.0, 3.0], - "vector-b": [4.0, 5.0], - }, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_b.external_id, - "vectors": { - "vector-b": [1.0, 2.0], - }, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_c.external_id, - "vectors": {}, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - "total": 3, - } - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_offset( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"offset": 2} - ) - - assert response.status_code == 200 - - response_body = response.json() - assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_limit( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"limit": 1} - ) - - assert response.status_code == 200 - - response_body = response.json() - assert [item["id"] for item in response_body["items"]] == [str(record_a.id)] - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_offset_and_limit( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - await RecordFactory.create(fields={"record_a": 
"value_a"}, dataset=dataset) - record_c = await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"offset": 1, "limit": 1} - ) - - assert response.status_code == 200 - - response_body = response.json() - assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - - @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), - [ - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), - ), - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, le=None), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(ge=None, le=20), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, le=None), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(ge=None, le=11.32), - ), - ], - ) - async def test_list_current_user_dataset_records_with_metadata_filter( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - metadata_property = await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - - response_json = response.json() - assert response_json["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, - user_id=owner.id, - ) - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize("response_status_filter", ["missing", "pending", "discarded", "submitted", "draft"]) - async def 
test_list_current_user_dataset_records_with_response_status_filter( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict, response_status_filter: str - ): - num_responses_per_status = 10 - response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}} - - dataset = await DatasetFactory.create() - # missing responses - await RecordFactory.create_batch(size=num_responses_per_status, dataset=dataset) - # discarded responses - await self.create_records_with_response(num_responses_per_status, dataset, owner, ResponseStatus.discarded) - # submitted responses - await self.create_records_with_response( - num_responses_per_status, dataset, owner, ResponseStatus.submitted, response_values - ) - # drafted responses - await self.create_records_with_response( - num_responses_per_status, dataset, owner, ResponseStatus.draft, response_values - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records?response_status={response_status_filter}&include=responses", - headers=owner_auth_header, - ) - - assert response.status_code == 200 - response_json = response.json() - - assert len(response_json["items"]) == num_responses_per_status - - if response_status_filter in ["missing", "pending"]: - assert all([len(record["responses"]) == 0 for record in response_json["items"]]) - else: - assert all( - [record["responses"][0]["status"] == response_status_filter for record in response_json["items"]] - ) - - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - ], - ) - async def test_list_current_user_dataset_records_with_sort_by( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: "User", - owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = { - "sort_by": [f"{field}:{order}" if 
order is not None else f"{field}:asc" for field, order in sorts] - } - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - assert response.json()["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, - user_id=owner.id, - ) - - async def test_list_current_user_dataset_records_with_sort_by_with_wrong_sort_order_value( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"sort_by": "inserted_at:wrong"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } - - async def test_list_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"sort_by": "metadata.i-do-not-exist:asc"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." - } - - async def test_list_current_user_dataset_records_with_sort_by_with_invalid_field( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"sort_by": "not-valid"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } - - async def test_list_current_user_dataset_records_without_authentication(self, async_client: "AsyncClient"): - dataset = await DatasetFactory.create() - - response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/records") - - assert response.status_code == 401 - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize("role", [UserRole.admin, UserRole.annotator]) - async def test_list_current_user_dataset_records_as_restricted_user( - self, async_client: "AsyncClient", role: UserRole - ): - workspace = await WorkspaceFactory.create() - user = await UserFactory.create(workspaces=[workspace], role=role) - dataset = await DatasetFactory.create(workspace=workspace) - record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - record_b = await RecordFactory.create( - fields={"record_b": "value_b"}, metadata_={"unit": "test"}, dataset=dataset - ) - record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - expected_records = [record_a, record_b, record_c] - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key} - ) - - assert response.status_code == 200 - - response_items = response.json()["items"] - - for expected_record in expected_records: - found_items = [item for item in response_items if item["id"] == str(expected_record.id)] - assert found_items, expected_record - - assert found_items[0] == { - "id": str(expected_record.id), - "fields": expected_record.fields, - "metadata": expected_record.metadata_, - "external_id": expected_record.external_id, - "inserted_at": expected_record.inserted_at.isoformat(), - "updated_at": expected_record.updated_at.isoformat(), - } - - @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) - async def test_list_current_user_dataset_records_as_restricted_user_from_different_workspace( - self, async_client: "AsyncClient", role: UserRole - ): - dataset = await DatasetFactory.create() - workspace = await WorkspaceFactory.create() - user = await UserFactory.create(workspaces=[workspace], role=role) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key} - ) - - assert response.status_code == 403 - - async def test_list_current_user_dataset_records_with_nonexistent_dataset_id( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset_id = uuid4() - - await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset_id}/records", - headers=owner_auth_header, - ) - - assert response.status_code == 404 - assert response.json() == {"detail": f"Dataset with id `{dataset_id}` not found"} diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index 6c927d42a2..f57c115492 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -16,18 +16,27 @@ import pytest import pytest_asyncio -from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus +from argilla_server.enums import ( + MetadataPropertyType, + QuestionType, + ResponseStatusFilter, + 
SimilarityOrder, + RecordStatus, + SortOrder, +) from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, ResponseFilterScope, SortBy, SuggestionFilterScope, TermsFilter, - TermsMetadataFilter, TextQuery, UserResponseStatusFilter, + Filter, + MetadataFilterScope, + RangeFilter, + Order, + RecordFilterScope, ) from argilla_server.search_engine.commons import ( ALL_RESPONSES_STATUSES_FIELD, @@ -595,7 +604,7 @@ async def test_search_with_response_status_filter( result = await search_engine.search( test_banking_sentiment_dataset, query=TextQuery(q="payment"), - user_response_status_filter=UserResponseStatusFilter(user=user, statuses=statuses), + filter=TermsFilter(scope=ResponseFilterScope(property="status"), values=statuses), ) assert len(result.items) == expected_items assert result.total == expected_items @@ -669,26 +678,26 @@ async def test_search_with_response_status_filter_with_no_user( result = await search_engine.search( test_banking_sentiment_dataset, - user_response_status_filter=UserResponseStatusFilter(statuses=statuses, user=None), + filter=TermsFilter(ResponseFilterScope(property="status"), values=statuses), ) assert len(result.items) == expected_items assert result.total == expected_items @pytest.mark.parametrize( - ("metadata_filters_config", "expected_items"), + ("filter", "expected_items"), [ - ([{"name": "label", "values": ["neutral"]}], 4), - ([{"name": "label", "values": ["positive"]}], 1), - ([{"name": "label", "values": ["neutral", "positive"]}], 5), - ([{"name": "textId", "ge": 3, "le": 4}], 2), - ([{"name": "textId", "ge": 3, "le": 3}], 1), - ([{"name": "textId", "ge": 3}], 6), - ([{"name": "textId", "le": 4}], 5), - ([{"name": "seq_float", "ge": 0.0, "le": 12.03}], 3), - ([{"name": "seq_float", "ge": 0.13, "le": 0.13}], 1), - ([{"name": "seq_float", "ge": 0.0}], 7), - ([{"name": "seq_float", "le": 12.03}], 5), + (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["neutral"]), 4), + (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["positive"]), 1), + (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["neutral", "positive"]), 5), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=4), 2), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=3), 1), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3), 6), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), le=4), 5), + (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0, le=12.03), 3), + (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.13, le=0.13), 1), + (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.0), 7), + (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), le=12.03), 5), ], ) async def test_search_with_metadata_filter( @@ -696,24 +705,10 @@ async def test_search_with_metadata_filter( search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - metadata_filters_config: List[dict], + filter: Filter, expected_items: int, ): - metadata_filters = [] - for metadata_filter_config in metadata_filters_config: - name = metadata_filter_config.pop("name") - for metadata_property in test_banking_sentiment_dataset.metadata_properties: - if name == metadata_property.name: - if metadata_property.type == 
MetadataPropertyType.terms: - filter_cls = TermsMetadataFilter - elif metadata_property.type == MetadataPropertyType.integer: - filter_cls = IntegerMetadataFilter - else: - filter_cls = FloatMetadataFilter - metadata_filters.append(filter_cls(metadata_property=metadata_property, **metadata_filter_config)) - break - - result = await search_engine.search(test_banking_sentiment_dataset, metadata_filters=metadata_filters) + result = await search_engine.search(test_banking_sentiment_dataset, filter=filter) assert len(result.items) == expected_items assert result.total == expected_items @@ -748,7 +743,7 @@ async def test_search_with_response_status_filter_does_not_affect_the_result_sco results = await search_engine.search( test_banking_sentiment_dataset, query=TextQuery(q="payment"), - user_response_status_filter=UserResponseStatusFilter(user=user, statuses=all_statuses), + filter=TermsFilter(scope=ResponseFilterScope(property="status", user=user), values=all_statuses), ) assert len(no_filter_results.items) == len(results.items) @@ -834,12 +829,12 @@ async def test_search_with_pagination( assert all_results.items[offset : offset + limit] == results.items @pytest.mark.parametrize( - ("sort_by"), + ("sort_order"), [ - SortBy(field="inserted_at"), - SortBy(field="updated_at"), - SortBy(field="inserted_at", order="desc"), - SortBy(field="updated_at", order="desc"), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), ], ) async def test_search_with_sort_by( @@ -847,18 +842,15 @@ async def test_search_with_sort_by( search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - sort_by: SortBy, + sort_order: Order, ): def _local_sort_by(record: Record) -> Any: - if isinstance(sort_by.field, str): - return getattr(record, sort_by.field) - return record.metadata_[sort_by.field.name] + return getattr(record, sort_order.scope.property) - results = await search_engine.search(test_banking_sentiment_dataset, sort_by=[sort_by]) + results = await search_engine.search(test_banking_sentiment_dataset, sort=[sort_order]) records = test_banking_sentiment_dataset.records - if sort_by: - records = sorted(records, key=_local_sort_by, reverse=sort_by.order == "desc") + records = sorted(records, key=_local_sort_by, reverse=sort_order.order == "desc") assert [item.record_id for item in results.items] == [record.id for record in records] @@ -1348,32 +1340,34 @@ async def test_similarity_search_by_vector_value_with_order( assert responses.items[0].record_id != selected_record.id @pytest.mark.parametrize( - "user_response_status_filter", + "statuses", [ - None, - UserResponseStatusFilter(statuses=[ResponseStatusFilter.missing, ResponseStatusFilter.draft]), + [], + [ResponseStatusFilter.missing, ResponseStatusFilter.draft], ], ) - async def test_similarity_search_by_record_and_user_response_filter( + async def test_similarity_search_by_record_and_response_status_filter( self, search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset_with_vectors: Dataset, - user_response_status_filter: UserResponseStatusFilter, + statuses: List[ResponseStatusFilter], ): selected_record: Record = test_banking_sentiment_dataset_with_vectors.records[0] vector_settings: VectorSettings = 
test_banking_sentiment_dataset_with_vectors.vectors_settings[0] - if user_response_status_filter: + scope = ResponseFilterScope(property="status") + + if statuses: test_user = await UserFactory.create() - user_response_status_filter.user = test_user + scope.user = test_user responses = await search_engine.similarity_search( dataset=test_banking_sentiment_dataset_with_vectors, vector_settings=vector_settings, record=selected_record, max_results=1, - user_response_status_filter=user_response_status_filter, + filter=TermsFilter(scope=scope, values=statuses), ) assert responses.total == 1 From b456600b63220ad3846bd1db7a162122a90a3a52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Tue, 16 Jul 2024 09:39:21 +0200 Subject: [PATCH 24/34] improvement: capture and retry database concurrent update errors (#5227) # Description After investigating timeouts for PostgreSQL, I have found that timeouts do not affect the errors raised when a SERIALIZABLE transaction is rolled back because of a concurrent update. So the only way to support concurrent updates with PostgreSQL and SERIALIZABLE transactions is to capture these errors and retry the transaction. This change includes the following: * Start using the `backoff` library to retry the CRUD context functions that update responses and record statuses, using SERIALIZABLE database sessions. * This change has the side effect of working with both PostgreSQL and SQLite. * I have set a fixed maximum of 15 seconds for retrying with exponential backoff. * I have moved search engine updates outside of the transaction block. * This should mitigate errors in high-concurrency scenarios for PostgreSQL and SQLite: * For SQLite, we have an additional setting to configure a timeout if necessary. * I have changed the `DEFAULT_DATABASE_SQLITE_TIMEOUT` value to `5` seconds so the backoff logic can handle possible locked-database errors with SQLite. Refs #5000 **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** - [x] Manually testing with PostgreSQL and SQLite, running benchmarks using 20 concurrent requests. - [x] Running test suite for PostgreSQL and SQLite. 
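The retry approach can be illustrated with a minimal, hypothetical sketch. The decorator arguments and the 15-second limit mirror the `backoff.on_exception(...)` usage and the `MAX_TIME_RETRY_SQLALCHEMY_ERROR` constant added below in `contexts/datasets.py`; the function name `update_record_with_retry` and its placeholder body are illustrative only, not the actual Argilla code:

```python
import backoff
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

MAX_TIME_RETRY_SQLALCHEMY_ERROR = 15  # seconds, same value as the constant added in contexts/datasets.py


@backoff.on_exception(backoff.expo, SQLAlchemyError, max_time=MAX_TIME_RETRY_SQLALCHEMY_ERROR)
async def update_record_with_retry(db: AsyncSession) -> None:
    # The whole coroutine is retried with exponential backoff whenever SQLAlchemy raises,
    # e.g. when PostgreSQL rolls back a SERIALIZABLE transaction after a concurrent update.
    async with db.begin_nested():
        ...  # database writes (response rows, record status) go here

    await db.commit()

    # Search engine updates would happen here, after the commit and outside
    # the transaction block, as described above.
```

With this pattern a serialization failure simply re-runs the function until it succeeds or the 15-second budget is exhausted, which covers both PostgreSQL serialization errors and SQLite "database is locked" errors.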
**Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../src/argilla_server/constants.py | 2 +- .../src/argilla_server/contexts/datasets.py | 27 +++++++++++++------ .../responses/upsert_responses_in_bulk.py | 1 + .../tests/unit/commons/test_settings.py | 2 +- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/argilla-server/src/argilla_server/constants.py b/argilla-server/src/argilla_server/constants.py index b65f18ca0a..eb419d989c 100644 --- a/argilla-server/src/argilla_server/constants.py +++ b/argilla-server/src/argilla_server/constants.py @@ -25,7 +25,7 @@ DEFAULT_PASSWORD = "1234" DEFAULT_API_KEY = "argilla.apikey" -DEFAULT_DATABASE_SQLITE_TIMEOUT = 15 +DEFAULT_DATABASE_SQLITE_TIMEOUT = 5 DEFAULT_DATABASE_POSTGRESQL_POOL_SIZE = 15 DEFAULT_DATABASE_POSTGRESQL_MAX_OVERFLOW = 10 diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index b95fcde5e1..7f5e49b6dd 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import backoff import copy from datetime import datetime from typing import ( @@ -94,6 +95,8 @@ CREATE_DATASET_VECTOR_SETTINGS_MAX_COUNT = 5 +MAX_TIME_RETRY_SQLALCHEMY_ERROR = 15 + async def _touch_dataset_last_activity_at(db: AsyncSession, dataset: Dataset) -> None: await db.execute( @@ -805,6 +808,7 @@ async def delete_record(db: AsyncSession, search_engine: "SearchEngine", record: return record +@backoff.on_exception(backoff.expo, sqlalchemy.exc.SQLAlchemyError, max_time=MAX_TIME_RETRY_SQLALCHEMY_ERROR) async def create_response( db: AsyncSession, search_engine: SearchEngine, record: Record, user: User, response_create: ResponseCreate ) -> Response: @@ -828,16 +832,18 @@ async def create_response( await db.flush([response]) await _load_users_from_responses([response]) await _touch_dataset_last_activity_at(db, record.dataset) - await search_engine.update_record_response(response) await db.refresh(record, attribute_names=[Record.responses_submitted.key]) await distribution.update_record_status(db, record) - await search_engine.partial_record_update(record, status=record.status) await db.commit() + await search_engine.update_record_response(response) + await search_engine.partial_record_update(record, status=record.status) + return response +@backoff.on_exception(backoff.expo, sqlalchemy.exc.SQLAlchemyError, max_time=MAX_TIME_RETRY_SQLALCHEMY_ERROR) async def update_response( db: AsyncSession, search_engine: SearchEngine, response: Response, response_update: ResponseUpdate ): @@ -854,16 +860,18 @@ async def update_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) - await search_engine.update_record_response(response) await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) await distribution.update_record_status(db, response.record) - await search_engine.partial_record_update(response.record, status=response.record.status) await db.commit() + await 
search_engine.update_record_response(response) + await search_engine.partial_record_update(response.record, status=response.record.status) + return response +@backoff.on_exception(backoff.expo, sqlalchemy.exc.SQLAlchemyError, max_time=MAX_TIME_RETRY_SQLALCHEMY_ERROR) async def upsert_response( db: AsyncSession, search_engine: SearchEngine, record: Record, user: User, response_upsert: ResponseUpsert ) -> Response: @@ -886,29 +894,32 @@ async def upsert_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) - await search_engine.update_record_response(response) await db.refresh(record, attribute_names=[Record.responses_submitted.key]) await distribution.update_record_status(db, record) - await search_engine.partial_record_update(record, status=record.status) await db.commit() + await search_engine.update_record_response(response) + await search_engine.partial_record_update(record, status=record.status) + return response +@backoff.on_exception(backoff.expo, sqlalchemy.exc.SQLAlchemyError, max_time=MAX_TIME_RETRY_SQLALCHEMY_ERROR) async def delete_response(db: AsyncSession, search_engine: SearchEngine, response: Response) -> Response: async with db.begin_nested(): response = await response.delete(db, autocommit=False) await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) - await search_engine.delete_record_response(response) await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) await distribution.update_record_status(db, response.record) - await search_engine.partial_record_update(record=response.record, status=response.record.status) await db.commit() + await search_engine.delete_record_response(response) + await search_engine.partial_record_update(record=response.record, status=response.record.status) + return response diff --git a/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py b/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py index 520194e46a..cb801365e7 100644 --- a/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py +++ b/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py @@ -44,6 +44,7 @@ async def execute(self, responses: List[ResponseUpsert], user: User) -> List[Res raise errors.NotFoundError(f"Record with id `{item.record_id}` not found") await authorize(user, RecordPolicy.create_response(record)) + response = await datasets.upsert_response(self.db, self.search_engine, record, user, item) except Exception as err: responses_bulk_items.append(ResponseBulk(item=None, error=ResponseBulkError(detail=str(err)))) diff --git a/argilla-server/tests/unit/commons/test_settings.py b/argilla-server/tests/unit/commons/test_settings.py index 8215709b4a..1ef0b64849 100644 --- a/argilla-server/tests/unit/commons/test_settings.py +++ b/argilla-server/tests/unit/commons/test_settings.py @@ -75,7 +75,7 @@ def test_settings_database_url(url: str, expected_url: str, monkeypatch): def test_settings_default_database_sqlite_timeout(): - assert Settings().database_sqlite_timeout == 15 + assert Settings().database_sqlite_timeout == 5 def test_settings_database_sqlite_timeout(monkeypatch): From 8dd1c7e91651dd9581fd009c160af79d503cf0da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Tue, 16 Jul 2024 11:19:36 +0200 Subject: [PATCH 25/34] chore: update CHANGELOG.md --- argilla-server/CHANGELOG.md | 14 
++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 3b7a833ff9..e2ff333f7e 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -20,6 +20,7 @@ These are the section headers that we use: - Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013)) - Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028)) +- Added new `status` column to `records` table. ([#5132](https://github.com/argilla-io/argilla/pull/5132)) - Added new `ARGILLA_DATABASE_SQLITE_TIMEOUT` environment variable allowing to set transactions timeout for SQLite. ([#5213](https://github.com/argilla-io/argilla/pull/5213)) - Added new `ARGILLA_DATABASE_POSTGRESQL_POOL_SIZE` environment variable allowing to set the number of connections to keep open inside the database connection pool. ([#5220](https://github.com/argilla-io/argilla/pull/5220)) - Added new `ARGILLA_DATABASE_POSTGRESQL_MAX_OVERFLOW` environment variable allowing to set the number of connections that can be opened above and beyond the `ARGILLA_DATABASE_POSTGRESQL_POOL_SIZE` setting. ([#5220](https://github.com/argilla-io/argilla/pull/5220)) @@ -36,22 +37,15 @@ These are the section headers that we use: ### Removed -- [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) -- [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) - -### Removed - -- Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) +- [breaking] Removed deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) +- [breaking] Removed deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) +- [breaking] Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) - [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) - [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) - [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166)) ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) -### Changed - -- Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) - ### Removed - Removed all API v0 endpoints. 
([#4852](https://github.com/argilla-io/argilla/pull/4852))

From 20ae66341c6069894d987a8c3e75ef3d1e2d7079 Mon Sep 17 00:00:00 2001
From: Leire
Date: Thu, 18 Jul 2024 16:53:04 +0200
Subject: [PATCH 26/34] =?UTF-8?q?=F0=9F=94=80=20Update=20UI=20for=20distri?=
 =?UTF-8?q?bution=20task=20(#5219)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- [x] Update progress bar styles
- [x] Show two decimals in the progress bar of the dataset list
- [x] Remove the donut chart and replace it with small cards
- [x] Replace the "my progress" bar with the team progress bar
- [x] Show submitted info when the panel is collapsed

---------

Co-authored-by: Damián Pumar
Co-authored-by: David Berenstein
---
 .../BaseCollapsablePanel.vue | 26 ++--
 .../base/base-progress/BaseLinearProgress.vue | 10 ++-
 .../BaseLinearProgressSkeleton.vue | 4 +-
 .../base-resizable/HorizontalResizable.vue | 1 +
 .../base/base-resizable/VerticalResizable.vue | 4 +-
 .../base/base-tabs/BaseTabsAndContent.vue | 11 ++-
 .../container/fields/RecordFieldsHeader.vue | 2 +-
 .../container/fields/RecordStatus.vue | 37 ++++--
 .../container/mode/BulkAnnotation.vue | 12 ++-
 .../container/mode/FocusAnnotation.vue | 12 ++-
 .../container/questions/QuestionsForm.vue | 3 -
 .../progress/AnnotationProgress.vue | 85 ++++++++++++------
 .../progress/AnnotationProgressDetailed.vue | 89 +++++++++++--------
 .../annotation/progress/BarProgress.vue | 67 --------------
 .../annotation/progress/TeamProgress.vue | 69 ++++++++++++++
 ...l.ts => useAnnotationProgressViewModel.ts} | 24 ++---
 .../progress/useTeamProgressViewModel.ts | 27 ++++++
 .../annotation/settings/SettingsInfo.vue | 63 ++++++++++++-
 .../annotation/settings/SettingsMetadata.vue | 1 +
 .../features/datasets/DatasetsTable.vue | 2 +-
 .../dataset-progress/DatasetProgress.vue | 10 +--
 .../useDatasetProgressViewModel.ts | 57 ++++++------
 .../pages/dataset/_id/settings.vue | 2 +
 .../dataset/_id/useDatasetSettingViewModel.ts | 41 +++++++--
 argilla-frontend/translation/de.js | 7 +-
 argilla-frontend/translation/en.js | 15 +++-
 argilla-frontend/v1/di/di.ts | 10 ++-
 .../v1/domain/entities/dataset/Dataset.ts | 22 ++++-
 .../v1/domain/entities/dataset/Progress.ts | 12 ++-
 .../distribution/TaskDistribution.test.ts | 17 ++++
 .../entities/distribution/TaskDistribution.ts | 7 ++
 .../v1/domain/entities/record/Record.ts | 4 +
 .../v1/domain/entities/record/RecordStatus.ts | 5 +-
 .../v1/domain/entities/record/Records.test.ts | 54 +++++++++
 .../domain/services/ITeamProgressStorage.ts | 5 ++
 .../usecases/get-dataset-progress-use-case.ts | 14 ++-
 .../get-records-by-criteria-use-case.ts | 2 +
 .../events/UpdateTeamProgressEventHandler.ts | 16 ++++
 .../v1/infrastructure/events/index.ts | 1 +
 .../v1/infrastructure/events/useEvents.ts | 15 ++-
 .../repositories/DatasetRepository.ts | 43 ++++---
 .../v1/infrastructure/services/useRoutes.ts | 3 +-
 .../storage/TeamProgressStorage.ts | 8 ++
 .../v1/infrastructure/types/dataset.ts | 24 +++++
 .../v1/infrastructure/types/record.ts | 1 +
 45 files changed, 687 insertions(+), 257 deletions(-)
 delete mode 100644 argilla-frontend/components/features/annotation/progress/BarProgress.vue
 create mode 100644 argilla-frontend/components/features/annotation/progress/TeamProgress.vue
 rename argilla-frontend/components/features/annotation/progress/{useFeedbackTaskProgressViewModel.ts => useAnnotationProgressViewModel.ts} (51%)
 create mode 100644 argilla-frontend/components/features/annotation/progress/useTeamProgressViewModel.ts
 create mode 100644 argilla-frontend/v1/domain/entities/distribution/TaskDistribution.test.ts
 create mode 100644 argilla-frontend/v1/domain/entities/distribution/TaskDistribution.ts
 create mode 100644 argilla-frontend/v1/domain/services/ITeamProgressStorage.ts
 create mode 100644 argilla-frontend/v1/infrastructure/events/UpdateTeamProgressEventHandler.ts
 create mode 100644 argilla-frontend/v1/infrastructure/storage/TeamProgressStorage.ts

diff --git a/argilla-frontend/components/base/base-collpasable-panel/BaseCollapsablePanel.vue b/argilla-frontend/components/base/base-collpasable-panel/BaseCollapsablePanel.vue
index 574f66e5a3..86406cdfea 100644
--- a/argilla-frontend/components/base/base-collpasable-panel/BaseCollapsablePanel.vue
+++ b/argilla-frontend/components/base/base-collpasable-panel/BaseCollapsablePanel.vue
@@ -7,13 +7,18 @@
     ]"
   >
[template markup changes in this hunk (7 additions, 2 deletions) are not preserved in this excerpt]
@@ -53,6 +58,13 @@ export default {
   border-top: 1px solid $black-10;
   &__header {
+    &__container {
+      width: 100%;
+      display: flex;
+      justify-content: space-between;
+      align-items: center;
+    }
+
     width: 100%;
     display: flex;
     justify-content: space-between;
diff --git a/argilla-frontend/components/base/base-progress/BaseLinearProgress.vue b/argilla-frontend/components/base/base-progress/BaseLinearProgress.vue
index 0a23b48c8e..4f9b6d67fd 100644
--- a/argilla-frontend/components/base/base-progress/BaseLinearProgress.vue
+++ b/argilla-frontend/components/base/base-progress/BaseLinearProgress.vue
@@ -67,7 +67,7 @@ export default {
   },
   methods: {
     getPercentage(value) {
-      return ((value / this.progressMax) * 100).toFixed();
+      return ((value / this.progressMax) * 100).toFixed(2);
     },
     getTrianglePosition(range) {
       if (!range) return;
@@ -86,10 +86,10 @@
-$progressHeight: 12px;
+$progressHeight: 14px;
 $tooltipBackgroundColor: palette(grey, 600);
 $tooltipTriangleSize: 5px;
-$borderRadius: 10px;
+$borderRadius: 3px;
 .progress {
   $this: &;
@@ -103,12 +103,14 @@
     height: $progressHeight;
     border-radius: $borderRadius;
     overflow: hidden;
+    background: palette(grey, 600);
+    box-shadow: 0 0 0 1px palette(white);
   }
   &__bar {
     position: relative;
     height: 100%;
     border-radius: $borderRadius;
-    margin: 0 -4px;
+    margin: 0 -1px;
     box-shadow: 0 0 0 1px palette(white);
     z-index: 1;
     &:after {
diff --git a/argilla-frontend/components/base/base-progress/BaseLinearProgressSkeleton.vue b/argilla-frontend/components/base/base-progress/BaseLinearProgressSkeleton.vue
index e29b01e287..c3a4d431d5 100644
--- a/argilla-frontend/components/base/base-progress/BaseLinearProgressSkeleton.vue
+++ b/argilla-frontend/components/base/base-progress/BaseLinearProgressSkeleton.vue
@@ -9,9 +9,9 @@ export default {};
-$progressHeight: 12px;
+$progressHeight: 14px;
 $progressBackgroundColor: #f2f2f2;
-$borderRadius: 10px;
+$borderRadius: 3px;
 .progress__wrapper {
   height: $progressHeight;
diff --git a/argilla-frontend/components/base/base-resizable/HorizontalResizable.vue b/argilla-frontend/components/base/base-resizable/HorizontalResizable.vue
index 77dccbac2d..19afcdb42d 100644
--- a/argilla-frontend/components/base/base-resizable/HorizontalResizable.vue
+++ b/argilla-frontend/components/base/base-resizable/HorizontalResizable.vue
@@ -15,6 +15,7 @@
     @toggle-expand="toggleExpand"
   >
[one added template line in this hunk is not preserved in this excerpt]
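A note on the `getPercentage` change in `BaseLinearProgress.vue` above: `toFixed()` with no argument rounds to a whole percent, while `toFixed(2)` keeps the two decimals called for in the PR checklist. A minimal TypeScript illustration of the difference (not part of the patch):

```ts
// Old vs. new formatting of a progress percentage, e.g. 1 of 3 records done.
const ratio = (1 / 3) * 100;

console.log(ratio.toFixed());  // "33"    – old behaviour, whole percent
console.log(ratio.toFixed(2)); // "33.33" – new behaviour, two decimals
```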
diff --git a/argilla-frontend/components/base/base-resizable/VerticalResizable.vue b/argilla-frontend/components/base/base-resizable/VerticalResizable.vue
index c19e5ecad3..bb3b577489 100644
--- a/argilla-frontend/components/base/base-resizable/VerticalResizable.vue
+++ b/argilla-frontend/components/base/base-resizable/VerticalResizable.vue
@@ -59,8 +59,8 @@ export default {
   },
   methods: {
     limitElementWidth(element) {
-      element.style["max-width"] = "65%";
-      element.style["min-width"] = "35%";
+      element.style["max-width"] = "62%";
+      element.style["min-width"] = "38%";
     },
     savePositionOnStartResizing(e) {
       this.leftSidePrevPosition = {
diff --git a/argilla-frontend/components/base/base-tabs/BaseTabsAndContent.vue b/argilla-frontend/components/base/base-tabs/BaseTabsAndContent.vue
index 0cbb3f014c..361bfa964c 100644
--- a/argilla-frontend/components/base/base-tabs/BaseTabsAndContent.vue
+++ b/argilla-frontend/components/base/base-tabs/BaseTabsAndContent.vue
@@ -49,11 +49,20 @@ export default {
       return this.currentTab.component;
     },
   },
+  watch: {
+    currentTab() {
+      this.$emit("onChanged", this.currentTab.id);
+    },
+  },
   methods: {
     getSelectedTab(id) {
       this.currentTab = this.tabs.find((tab) => tab.id === id);
     },
   },
+  mounted() {
+    this.$emit("onChanged", this.currentTab.id);
+
+    this.$emit("onLoaded");
+  },
 };
-
diff --git a/argilla-frontend/components/features/annotation/container/fields/RecordFieldsHeader.vue b/argilla-frontend/components/features/annotation/container/fields/RecordFieldsHeader.vue
index 79268b1776..01669a7a53 100644
--- a/argilla-frontend/components/features/annotation/container/fields/RecordFieldsHeader.vue
+++ b/argilla-frontend/components/features/annotation/container/fields/RecordFieldsHeader.vue
@@ -24,7 +24,7 @@
     :recordCriteria="recordCriteria"
     :recordId="record.id"
   />
[one template line is replaced here; its content is not preserved in this excerpt]
diff --git a/argilla-frontend/components/features/annotation/container/fields/RecordStatus.vue b/argilla-frontend/components/features/annotation/container/fields/RecordStatus.vue
index 7334f9d74d..4a80747e9c 100644
--- a/argilla-frontend/components/features/annotation/container/fields/RecordStatus.vue
+++ b/argilla-frontend/components/features/annotation/container/fields/RecordStatus.vue
@@ -1,47 +1,52 @@
[template and script changes in this hunk are not preserved in this excerpt]
diff --git a/argilla-frontend/components/features/annotation/progress/AnnotationProgressDetailed.vue b/argilla-frontend/components/features/annotation/progress/AnnotationProgressDetailed.vue
index 59718ea15b..7c4e1491d1 100644
--- a/argilla-frontend/components/features/annotation/progress/AnnotationProgressDetailed.vue
+++ b/argilla-frontend/components/features/annotation/progress/AnnotationProgressDetailed.vue
@@ -16,31 +16,31 @@
 -->
[template changes in this hunk are not preserved in this excerpt]
diff --git a/argilla-frontend/components/features/annotation/progress/BarProgress.vue b/argilla-frontend/components/features/annotation/progress/BarProgress.vue
deleted file mode 100644
index b063a1ffd5..0000000000
--- a/argilla-frontend/components/features/annotation/progress/BarProgress.vue
+++ /dev/null
@@ -1,67 +0,0 @@
[deleted file content (67 lines) is not preserved in this excerpt]
diff --git a/argilla-frontend/components/features/annotation/progress/TeamProgress.vue b/argilla-frontend/components/features/annotation/progress/TeamProgress.vue
new file mode 100644
index 0000000000..732a396d15
--- /dev/null
+++ b/argilla-frontend/components/features/annotation/progress/TeamProgress.vue
@@ -0,0 +1,69 @@
[new file content (69 lines) is not preserved in this excerpt]
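`BarProgress.vue` is deleted above and replaced by the new `TeamProgress.vue`, whose content is not preserved in this excerpt. As rough orientation only, here is a hedged TypeScript sketch of the kind of team-progress summary the small cards could be derived from; the `TeamProgressCounts` shape and the `describeTeamProgress` helper are assumptions, not code from the patch:

```ts
// Hypothetical helper for the cards that replace the donut chart; the field
// names (completed/pending/total) are assumptions, not taken from the patch.
interface TeamProgressCounts {
  completed: number;
  pending: number;
  total: number;
}

const toPercentage = (value: number, total: number): string =>
  total > 0 ? ((value / total) * 100).toFixed(2) : "0.00";

export const describeTeamProgress = (counts: TeamProgressCounts) => ({
  completed: `${counts.completed} (${toPercentage(counts.completed, counts.total)}%)`,
  pending: `${counts.pending} (${toPercentage(counts.pending, counts.total)}%)`,
});

// Example: 40 of 120 records completed.
console.log(describeTeamProgress({ completed: 40, pending: 80, total: 120 }));
// -> { completed: "40 (33.33%)", pending: "80 (66.67%)" }
```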
diff --git a/argilla-frontend/components/features/annotation/progress/useFeedbackTaskProgressViewModel.ts b/argilla-frontend/components/features/annotation/progress/useAnnotationProgressViewModel.ts
similarity index 51%
rename from argilla-frontend/components/features/annotation/progress/useFeedbackTaskProgressViewModel.ts
rename to argilla-frontend/components/features/annotation/progress/useAnnotationProgressViewModel.ts
index d7d4febf4c..bec14e00ac 100644
--- a/argilla-frontend/components/features/annotation/progress/useFeedbackTaskProgressViewModel.ts
+++ b/argilla-frontend/components/features/annotation/progress/useAnnotationProgressViewModel.ts
@@ -7,32 +7,22 @@ import {
 } from "~/v1/infrastructure/events";
 import { useMetrics } from "~/v1/infrastructure/storage/MetricsStorage";

-interface FeedbackTaskProgressProps {
+interface AnnotationProgressProps {
   datasetId: string;
-  enableFetch?: boolean;
 }
-
-export const useFeedbackTaskProgressViewModel = (
-  props: FeedbackTaskProgressProps
+export const useAnnotationProgressViewModel = (
+  props: AnnotationProgressProps
 ) => {
-  const { state: datasetMetrics } = useMetrics();
+  const { state: metrics } = useMetrics();
   const getMetrics = useResolve(GetUserMetricsUseCase);

-  const loadMetrics = (datasetId: string) => {
-    getMetrics.execute(datasetId);
-  };
-
   onBeforeMount(() => {
-    if (!props.enableFetch) return;
-
-    useEvents(() => {
-      new UpdateMetricsEventHandler();
-    });
+    useEvents(UpdateMetricsEventHandler);

-    loadMetrics(props.datasetId);
+    getMetrics.execute(props.datasetId);
   });

   return {
-    datasetMetrics,
+    metrics,
   };
 };
diff --git a/argilla-frontend/components/features/annotation/progress/useTeamProgressViewModel.ts b/argilla-frontend/components/features/annotation/progress/useTeamProgressViewModel.ts
new file mode 100644
index 0000000000..099836a4ac
--- /dev/null
+++ b/argilla-frontend/components/features/annotation/progress/useTeamProgressViewModel.ts
@@ -0,0 +1,27 @@
+import { useResolve } from "ts-injecty";
+import { onBeforeMount } from "vue-demi";
+import { GetDatasetProgressUseCase } from "~/v1/domain/usecases/get-dataset-progress-use-case";
+import {
+  useEvents,
+  UpdateTeamProgressEventHandler,
+} from "~/v1/infrastructure/events";
+import { useTeamProgress } from "~/v1/infrastructure/storage/TeamProgressStorage";
+
+interface TeamProgressProps {
+  datasetId: string;
+}
+
+export const useTeamProgressViewModel = (props: TeamProgressProps) => {
+  const { state: progress } = useTeamProgress();
+  const getDatasetProgress = useResolve(GetDatasetProgressUseCase);
+
+  onBeforeMount(() => {
+    useEvents(UpdateTeamProgressEventHandler);
+
+    getDatasetProgress.execute(props.datasetId);
+  });
+
+  return {
+    progress,
+  };
+};
diff --git a/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue b/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue
index dc36485523..33d1bae137 100644
--- a/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue
+++ b/argilla-frontend/components/features/annotation/settings/SettingsInfo.vue
@@ -26,10 +26,36 @@
   @submit.prevent="onSubmit()"
   class="settings__edition-form-fields"
 >
[the 26 added template lines in this hunk are not preserved in this excerpt]
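The new `useTeamProgressViewModel` shown above loads the dataset progress once on mount and keeps its `progress` state in sync through `UpdateTeamProgressEventHandler`. A minimal sketch of a consuming component follows; since the real `TeamProgress.vue` markup is not part of this excerpt, the component and its import path are illustrative assumptions only:

```ts
import { defineComponent } from "vue-demi";
import { useTeamProgressViewModel } from "./useTeamProgressViewModel";

// Illustrative consumer of the view model added in this patch; not the actual
// TeamProgress.vue component, whose contents are not preserved here.
export default defineComponent({
  name: "TeamProgressExample",
  props: {
    datasetId: { type: String, required: true },
  },
  setup(props) {
    // Loads dataset progress on mount and re-renders when the
    // UpdateTeamProgressEventHandler refreshes TeamProgressStorage.
    const { progress } = useTeamProgressViewModel({ datasetId: props.datasetId });

    return { progress };
  },
});
```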