From 743c945c8dc3de6d458a257b9f4bec977381965f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 19 Jul 2024 16:30:26 +0200 Subject: [PATCH 01/88] feat: add support to new fields of type image --- .../argilla_server/api/schemas/v1/fields.py | 57 ++++++++- argilla-server/src/argilla_server/enums.py | 1 + argilla-server/tests/factories.py | 12 +- .../fields/test_create_dataset_field.py | 113 ++++++++++++++++++ .../fields/test_list_dataset_fields.py | 79 ++++++++++++ .../handlers/v1/fields/test_update_field.py | 94 +++++++++++++++ 6 files changed, 350 insertions(+), 6 deletions(-) create mode 100644 argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py create mode 100644 argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py create mode 100644 argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py diff --git a/argilla-server/src/argilla_server/api/schemas/v1/fields.py b/argilla-server/src/argilla_server/api/schemas/v1/fields.py index d67acc2525..dbd28da5a4 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/fields.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/fields.py @@ -13,12 +13,12 @@ # limitations under the License. from datetime import datetime -from typing import Annotated, List, Literal, Optional +from typing import Annotated, List, Literal, Optional, Union from uuid import UUID from argilla_server.api.schemas.v1.commons import UpdateSchema from argilla_server.enums import FieldType -from argilla_server.pydantic_v1 import BaseModel, constr +from argilla_server.pydantic_v1 import BaseModel, constr, HttpUrl from argilla_server.pydantic_v1 import Field as PydanticField FIELD_CREATE_NAME_REGEX = r"^(?=.*[a-z0-9])[a-z0-9_-]+$" @@ -49,6 +49,11 @@ class TextFieldSettings(BaseModel): + type: Literal[FieldType.text] + use_markdown: bool + + +class TextFieldSettingsCreate(BaseModel): type: Literal[FieldType.text] use_markdown: bool = False @@ -58,12 +63,54 @@ class TextFieldSettingsUpdate(UpdateSchema): use_markdown: bool +class ImageFieldSettings(BaseModel): + type: Literal[FieldType.image] + url: HttpUrl + + +class ImageFieldSettingsCreate(BaseModel): + type: Literal[FieldType.image] + url: HttpUrl + + +class ImageFieldSettingsUpdate(BaseModel): + type: Literal[FieldType.image] + url: HttpUrl + + +FieldSettings = Annotated[ + Union[ + TextFieldSettings, + ImageFieldSettings, + ], + PydanticField(..., discriminator="type"), +] + + +FieldSettingsCreate = Annotated[ + Union[ + TextFieldSettingsCreate, + ImageFieldSettingsCreate, + ], + PydanticField(..., discriminator="type"), +] + + +FieldSettingsUpdate = Annotated[ + Union[ + TextFieldSettingsUpdate, + ImageFieldSettingsUpdate, + ], + PydanticField(..., discriminator="type"), +] + + class Field(BaseModel): id: UUID name: str title: str required: bool - settings: TextFieldSettings + settings: FieldSettings dataset_id: UUID inserted_at: datetime updated_at: datetime @@ -80,11 +127,11 @@ class FieldCreate(BaseModel): name: FieldName title: FieldTitle required: Optional[bool] - settings: TextFieldSettings + settings: FieldSettingsCreate class FieldUpdate(UpdateSchema): title: Optional[FieldTitle] - settings: Optional[TextFieldSettingsUpdate] + settings: Optional[FieldSettingsUpdate] __non_nullable_fields__ = {"title", "settings"} diff --git a/argilla-server/src/argilla_server/enums.py b/argilla-server/src/argilla_server/enums.py index 2edc53d28f..3fc8f35ac8 100644 --- a/argilla-server/src/argilla_server/enums.py +++ b/argilla-server/src/argilla_server/enums.py @@ -17,6 +17,7 @@ class FieldType(str, Enum): text = "text" + image = "image" class ResponseStatus(str, Enum): diff --git a/argilla-server/tests/factories.py b/argilla-server/tests/factories.py index c429fed9af..b5f70d8e44 100644 --- a/argilla-server/tests/factories.py +++ b/argilla-server/tests/factories.py @@ -255,7 +255,17 @@ class Meta: class TextFieldFactory(FieldFactory): - settings = {"type": FieldType.text.value, "use_markdown": False} + settings = { + "type": FieldType.text, + "use_markdown": False, + } + + +class ImageFieldFactory(FieldFactory): + settings = { + "type": FieldType.image, + "url": "https://argilla.io/image.jpeg", + } class MetadataPropertyFactory(BaseFactory): diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py new file mode 100644 index 0000000000..65e3ce4d39 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py @@ -0,0 +1,113 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import Any +from uuid import UUID +from datetime import datetime +from httpx import AsyncClient +from sqlalchemy import func, select + +from argilla_server.enums import FieldType +from argilla_server.models import Field + +from tests.factories import DatasetFactory + + +@pytest.mark.asyncio +class TestCreateDatasetField: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/fields" + + async def test_create_dataset_image_field( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "name": "name", + "title": "title", + "required": True, + "settings": { + "type": FieldType.image, + "url": "https://argilla.io/image.jpeg", + }, + }, + ) + + image_field = (await db.execute(select(Field))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(image_field.id), + "name": "name", + "title": "title", + "required": True, + "settings": { + "type": FieldType.image, + "url": "https://argilla.io/image.jpeg", + }, + "dataset_id": str(dataset.id), + "inserted_at": image_field.inserted_at.isoformat(), + "updated_at": image_field.updated_at.isoformat(), + } + + async def test_create_dataset_image_field_without_url( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "name": "name", + "title": "title", + "required": True, + "settings": { + "type": FieldType.image, + }, + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Field.id)))).scalar_one() == 0 + + @pytest.mark.parametrize("invalid_url", [None, "", " ", "wrong-url", "argilla.io", "http//argilla.io"]) + async def test_create_dataset_image_field_with_invalid_url( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict, invalid_url: Any + ): + dataset = await DatasetFactory.create() + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "name": "name", + "title": "title", + "required": True, + "settings": { + "type": FieldType.image, + "url": invalid_url, + }, + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Field.id)))).scalar_one() == 0 diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py new file mode 100644 index 0000000000..a27f41f1fd --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py @@ -0,0 +1,79 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient + +from argilla_server.enums import FieldType + +from tests.factories import DatasetFactory, FieldFactory + + +@pytest.mark.asyncio +class TestListDatasetFields: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/fields" + + async def test_list_dataset_fields_with_image_field(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + image_field_a = await FieldFactory.create( + settings={ + "type": FieldType.image, + "url": "https://argilla.io/image-a.jpeg", + }, + dataset=dataset, + ) + image_field_b = await FieldFactory.create( + settings={ + "type": FieldType.image, + "url": "https://argilla.io/image-b.jpeg", + }, + dataset=dataset, + ) + + response = await async_client.get(self.url(dataset.id), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(image_field_a.id), + "name": image_field_a.name, + "title": image_field_a.title, + "required": False, + "settings": { + "type": FieldType.image, + "url": "https://argilla.io/image-a.jpeg", + }, + "dataset_id": str(dataset.id), + "inserted_at": image_field_a.inserted_at.isoformat(), + "updated_at": image_field_a.updated_at.isoformat(), + }, + { + "id": str(image_field_b.id), + "name": image_field_b.name, + "title": image_field_b.title, + "required": False, + "settings": { + "type": FieldType.image, + "url": "https://argilla.io/image-b.jpeg", + }, + "dataset_id": str(dataset.id), + "inserted_at": image_field_b.inserted_at.isoformat(), + "updated_at": image_field_b.updated_at.isoformat(), + }, + ] + } diff --git a/argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py b/argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py new file mode 100644 index 0000000000..5d4d67b948 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py @@ -0,0 +1,94 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import Any +from uuid import UUID + +from argilla_server.enums import FieldType + +from tests.factories import ImageFieldFactory + + +@pytest.mark.asyncio +class TestUpdateField: + def url(self, field_id: UUID) -> str: + return f"/api/v1/fields/{field_id}" + + async def test_update_image_field(self, async_client: AsyncClient, owner_auth_header: dict): + image_field = await ImageFieldFactory.create() + + response = await async_client.patch( + self.url(image_field.id), + headers=owner_auth_header, + json={ + "settings": { + "type": FieldType.image, + "url": "https://argilla.io/updated-image.jpeg", + }, + }, + ) + + assert response.status_code == 200 + assert response.json() == { + "id": str(image_field.id), + "name": image_field.name, + "title": image_field.title, + "required": False, + "settings": { + "type": FieldType.image, + "url": "https://argilla.io/updated-image.jpeg", + }, + "dataset_id": str(image_field.dataset_id), + "inserted_at": image_field.inserted_at.isoformat(), + "updated_at": image_field.updated_at.isoformat(), + } + + async def test_update_dataset_image_field_without_url(self, async_client: AsyncClient, owner_auth_header: dict): + image_field = await ImageFieldFactory.create() + + response = await async_client.patch( + self.url(image_field.id), + headers=owner_auth_header, + json={ + "settings": { + "type": FieldType.image, + }, + }, + ) + + assert response.status_code == 422 + + @pytest.mark.parametrize("invalid_url", [None, "", " ", "wrong-url", "argilla.io", "http//argilla.io"]) + async def test_update_dataset_image_field_with_invalid_url( + self, async_client: AsyncClient, owner_auth_header: dict, invalid_url: Any + ): + image_field = await ImageFieldFactory.create() + + response = await async_client.patch( + self.url(image_field.id), + headers=owner_auth_header, + json={ + "settings": { + "type": FieldType.image, + "url": invalid_url, + }, + }, + ) + + assert response.status_code == 422 From ecf4ec3baff14bcc5c68c60afffc3c827caa8a07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 19 Jul 2024 16:38:14 +0200 Subject: [PATCH 02/88] chore: update CHANGELOG.md --- argilla-server/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 3a12ad1e9f..7d9640c56b 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -16,6 +16,12 @@ These are the section headers that we use: ## [Unreleased]() +### Added + +- Added new `image` type dataset field. ([#5279](https://github.com/argilla-io/argilla/pull/5279)) + +## [2.0.0](https://github.com/argilla-io/argilla/compare/v2.0.0rc1...v2.0.0) + > [!IMPORTANT] > This version includes changes related to the search index. So, a reindex is needed. From ff4a913b4248cc778f326146f9eb2b7bbb4360e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 26 Jul 2024 14:06:08 +0200 Subject: [PATCH 03/88] improvement: validate that records values with associated image fields has correct URLs --- .../argilla_server/api/schemas/v1/fields.py | 7 +- .../src/argilla_server/models/database.py | 9 +++ .../src/argilla_server/validators/records.py | 25 +++++++ argilla-server/tests/factories.py | 1 - .../fields/test_create_dataset_field.py | 54 +------------- .../fields/test_list_dataset_fields.py | 28 ++----- ...py => test_create_dataset_records_bulk.py} | 74 ++++++++++++++++++- .../handlers/v1/fields/test_update_field.py | 46 +----------- .../unit/api/handlers/v1/test_datasets.py | 74 ++++++++++++------- .../unit/database/models/test_field_model.py | 31 ++++++++ 10 files changed, 196 insertions(+), 153 deletions(-) rename argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/{test_create_dataset_records_in_bulk.py => test_create_dataset_records_bulk.py} (77%) create mode 100644 argilla-server/tests/unit/database/models/test_field_model.py diff --git a/argilla-server/src/argilla_server/api/schemas/v1/fields.py b/argilla-server/src/argilla_server/api/schemas/v1/fields.py index dbd28da5a4..3c9726b2bc 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/fields.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/fields.py @@ -18,7 +18,7 @@ from argilla_server.api.schemas.v1.commons import UpdateSchema from argilla_server.enums import FieldType -from argilla_server.pydantic_v1 import BaseModel, constr, HttpUrl +from argilla_server.pydantic_v1 import BaseModel, constr from argilla_server.pydantic_v1 import Field as PydanticField FIELD_CREATE_NAME_REGEX = r"^(?=.*[a-z0-9])[a-z0-9_-]+$" @@ -58,24 +58,21 @@ class TextFieldSettingsCreate(BaseModel): use_markdown: bool = False -class TextFieldSettingsUpdate(UpdateSchema): +class TextFieldSettingsUpdate(BaseModel): type: Literal[FieldType.text] use_markdown: bool class ImageFieldSettings(BaseModel): type: Literal[FieldType.image] - url: HttpUrl class ImageFieldSettingsCreate(BaseModel): type: Literal[FieldType.image] - url: HttpUrl class ImageFieldSettingsUpdate(BaseModel): type: Literal[FieldType.image] - url: HttpUrl FieldSettings = Annotated[ diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 6b9580dbb5..765b157983 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -27,6 +27,7 @@ from argilla_server.api.schemas.v1.questions import QuestionSettings from argilla_server.enums import ( DatasetStatus, + FieldType, MetadataPropertyType, QuestionType, RecordStatus, @@ -73,6 +74,14 @@ class Field(DatabaseModel): __table_args__ = (UniqueConstraint("name", "dataset_id", name="field_name_dataset_id_uq"),) + @property + def is_text(self): + return self.settings.get("type") == FieldType.text + + @property + def is_image(self): + return self.settings.get("type") == FieldType.image + def __repr__(self): return ( f"Field(id={str(self.id)!r}, name={self.name!r}, required={self.required!r}, " diff --git a/argilla-server/src/argilla_server/validators/records.py b/argilla-server/src/argilla_server/validators/records.py index 01a80badba..d20283b4cc 100644 --- a/argilla-server/src/argilla_server/validators/records.py +++ b/argilla-server/src/argilla_server/validators/records.py @@ -13,9 +13,11 @@ # limitations under the License. import copy + from abc import ABC, abstractmethod from typing import Dict, List, Union from uuid import UUID +from urllib.parse import ParseResultBytes, urlparse, ParseResult from sqlalchemy.ext.asyncio import AsyncSession @@ -39,6 +41,7 @@ def _validate_fields(self, dataset: Dataset) -> None: self._validate_required_fields(dataset, fields) self._validate_extra_fields(dataset, fields) + self._validate_image_fields(dataset, fields) def _validate_metadata(self, dataset: Dataset) -> None: metadata = self._record_change.metadata or {} @@ -71,6 +74,27 @@ def _validate_extra_fields(self, dataset: Dataset, fields: Dict[str, str]) -> No if fields_copy: raise UnprocessableEntityError(f"found fields values for non configured fields: {list(fields_copy.keys())}") + def _validate_image_fields(self, dataset: Dataset, fields: Dict[str, str]) -> None: + for field in filter(lambda field: field.is_image, dataset.fields): + if fields.get(field.name) is not None and not self._is_valid_url(fields.get(field.name)): + raise UnprocessableEntityError( + f"image field {field.name!r} has an invalid URL value: {fields.get(field.name)!r}" + ) + + def _is_valid_url(self, url: Union[str, None]) -> bool: + try: + parse_result = urlparse(url) + except ValueError: + return False + + return self._is_valid_web_url(parse_result) or self._is_valid_data_url(parse_result) + + def _is_valid_web_url(self, parse_result: Union[ParseResult, ParseResultBytes]) -> bool: + return all([parse_result.scheme in ["http", "https"], parse_result.netloc, parse_result.path]) + + def _is_valid_data_url(self, parse_result: Union[ParseResult, ParseResultBytes]) -> bool: + return all([parse_result.scheme == "data", parse_result.path]) + class RecordCreateValidator(RecordValidatorBase): def __init__(self, record_create: RecordCreate): @@ -92,6 +116,7 @@ def validate_for(self, dataset: Dataset) -> None: def _validate_duplicated_suggestions(self): if not self._record_change.suggestions: return + question_ids = [s.question_id for s in self._record_change.suggestions] if len(question_ids) != len(set(question_ids)): raise UnprocessableEntityError("found duplicate suggestions question IDs") diff --git a/argilla-server/tests/factories.py b/argilla-server/tests/factories.py index b5f70d8e44..781a8eaf54 100644 --- a/argilla-server/tests/factories.py +++ b/argilla-server/tests/factories.py @@ -264,7 +264,6 @@ class TextFieldFactory(FieldFactory): class ImageFieldFactory(FieldFactory): settings = { "type": FieldType.image, - "url": "https://argilla.io/image.jpeg", } diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py index 65e3ce4d39..88c7c8341d 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_create_dataset_field.py @@ -44,10 +44,7 @@ async def test_create_dataset_image_field( "name": "name", "title": "title", "required": True, - "settings": { - "type": FieldType.image, - "url": "https://argilla.io/image.jpeg", - }, + "settings": {"type": FieldType.image}, }, ) @@ -59,55 +56,8 @@ async def test_create_dataset_image_field( "name": "name", "title": "title", "required": True, - "settings": { - "type": FieldType.image, - "url": "https://argilla.io/image.jpeg", - }, + "settings": {"type": FieldType.image}, "dataset_id": str(dataset.id), "inserted_at": image_field.inserted_at.isoformat(), "updated_at": image_field.updated_at.isoformat(), } - - async def test_create_dataset_image_field_without_url( - self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.post( - self.url(dataset.id), - headers=owner_auth_header, - json={ - "name": "name", - "title": "title", - "required": True, - "settings": { - "type": FieldType.image, - }, - }, - ) - - assert response.status_code == 422 - assert (await db.execute(select(func.count(Field.id)))).scalar_one() == 0 - - @pytest.mark.parametrize("invalid_url", [None, "", " ", "wrong-url", "argilla.io", "http//argilla.io"]) - async def test_create_dataset_image_field_with_invalid_url( - self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict, invalid_url: Any - ): - dataset = await DatasetFactory.create() - - response = await async_client.post( - self.url(dataset.id), - headers=owner_auth_header, - json={ - "name": "name", - "title": "title", - "required": True, - "settings": { - "type": FieldType.image, - "url": invalid_url, - }, - }, - ) - - assert response.status_code == 422 - assert (await db.execute(select(func.count(Field.id)))).scalar_one() == 0 diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py index a27f41f1fd..ee657b4876 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/fields/test_list_dataset_fields.py @@ -19,7 +19,7 @@ from argilla_server.enums import FieldType -from tests.factories import DatasetFactory, FieldFactory +from tests.factories import DatasetFactory, ImageFieldFactory @pytest.mark.asyncio @@ -29,20 +29,8 @@ def url(self, dataset_id: UUID) -> str: async def test_list_dataset_fields_with_image_field(self, async_client: AsyncClient, owner_auth_header: dict): dataset = await DatasetFactory.create() - image_field_a = await FieldFactory.create( - settings={ - "type": FieldType.image, - "url": "https://argilla.io/image-a.jpeg", - }, - dataset=dataset, - ) - image_field_b = await FieldFactory.create( - settings={ - "type": FieldType.image, - "url": "https://argilla.io/image-b.jpeg", - }, - dataset=dataset, - ) + image_field_a = await ImageFieldFactory.create(dataset=dataset) + image_field_b = await ImageFieldFactory.create(dataset=dataset) response = await async_client.get(self.url(dataset.id), headers=owner_auth_header) @@ -54,10 +42,7 @@ async def test_list_dataset_fields_with_image_field(self, async_client: AsyncCli "name": image_field_a.name, "title": image_field_a.title, "required": False, - "settings": { - "type": FieldType.image, - "url": "https://argilla.io/image-a.jpeg", - }, + "settings": {"type": FieldType.image}, "dataset_id": str(dataset.id), "inserted_at": image_field_a.inserted_at.isoformat(), "updated_at": image_field_a.updated_at.isoformat(), @@ -67,10 +52,7 @@ async def test_list_dataset_fields_with_image_field(self, async_client: AsyncCli "name": image_field_b.name, "title": image_field_b.title, "required": False, - "settings": { - "type": FieldType.image, - "url": "https://argilla.io/image-b.jpeg", - }, + "settings": {"type": FieldType.image}, "dataset_id": str(dataset.id), "inserted_at": image_field_b.inserted_at.isoformat(), "updated_at": image_field_b.updated_at.isoformat(), diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py similarity index 77% rename from argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py rename to argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py index 7110e9ce62..d5d85389ef 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_bulk.py @@ -29,22 +29,24 @@ RatingQuestionFactory, SpanQuestionFactory, TextFieldFactory, + ImageFieldFactory, TextQuestionFactory, ) @pytest.mark.asyncio -class TestCreateDatasetRecordsInBulk: +class TestCreateDatasetRecordsBulk: def url(self, dataset_id: UUID) -> str: return f"/api/v1/datasets/{dataset_id}/records/bulk" - async def test_create_dataset_records( - self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict + async def test_create_dataset_records_bulk( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict ): dataset = await DatasetFactory.create(status=DatasetStatus.ready) await TextFieldFactory.create(name="prompt", dataset=dataset) await TextFieldFactory.create(name="response", dataset=dataset) + await ImageFieldFactory.create(name="image", dataset=dataset) text_question = await TextQuestionFactory.create(name="text-question", dataset=dataset) @@ -120,6 +122,7 @@ async def test_create_dataset_records( "fields": { "prompt": "Does exercise help reduce stress?", "response": "Exercise can definitely help reduce stress.", + "image": "https://argilla.io/image.jpeg", }, "external_id": "1", "responses": [ @@ -214,3 +217,68 @@ async def test_create_dataset_records( assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1 assert (await db.execute(select(func.count(Response.id)))).scalar_one() == 1 assert (await db.execute(select(func.count(Suggestion.id)))).scalar_one() == 6 + + @pytest.mark.parametrize( + "data_url", + [ + "data:image/jpeg;base64,/9j/4QC8RXhpZgAASUkqAAgAAAAGABIBAwABAAAA", + "data:image/webp;base64,UklGRhgCAABXRUJQVlA4WAoAAAAIAAAAHwAAFwA", + ], + ) + async def test_create_dataset_records_bulk_with_data_url_image_field( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict, data_url: str + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + await ImageFieldFactory.create(name="image", dataset=dataset) + await LabelSelectionQuestionFactory.create(dataset=dataset) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "image": data_url, + }, + }, + ], + }, + ) + + assert response.status_code == 201 + + assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1 + + @pytest.mark.parametrize( + "invalid_url", ["http://argilla.io", "https://argilla.io", "http:/argilla.io", "invalid-url", "data:"] + ) + async def test_create_dataset_records_bulk_with_invalid_image_field_url( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict, invalid_url: str + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + await ImageFieldFactory.create(name="image", dataset=dataset) + await LabelSelectionQuestionFactory.create(dataset=dataset) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "image": invalid_url, + }, + }, + ], + }, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": f"record at position 0 is not valid because image field 'image' has an invalid URL value: {invalid_url!r}", + } + + assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 0 diff --git a/argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py b/argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py index 5d4d67b948..7c07f7f162 100644 --- a/argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py +++ b/argilla-server/tests/unit/api/handlers/v1/fields/test_update_field.py @@ -37,10 +37,7 @@ async def test_update_image_field(self, async_client: AsyncClient, owner_auth_he self.url(image_field.id), headers=owner_auth_header, json={ - "settings": { - "type": FieldType.image, - "url": "https://argilla.io/updated-image.jpeg", - }, + "title": "Updated title", }, ) @@ -48,47 +45,10 @@ async def test_update_image_field(self, async_client: AsyncClient, owner_auth_he assert response.json() == { "id": str(image_field.id), "name": image_field.name, - "title": image_field.title, + "title": "Updated title", "required": False, - "settings": { - "type": FieldType.image, - "url": "https://argilla.io/updated-image.jpeg", - }, + "settings": {"type": FieldType.image}, "dataset_id": str(image_field.dataset_id), "inserted_at": image_field.inserted_at.isoformat(), "updated_at": image_field.updated_at.isoformat(), } - - async def test_update_dataset_image_field_without_url(self, async_client: AsyncClient, owner_auth_header: dict): - image_field = await ImageFieldFactory.create() - - response = await async_client.patch( - self.url(image_field.id), - headers=owner_auth_header, - json={ - "settings": { - "type": FieldType.image, - }, - }, - ) - - assert response.status_code == 422 - - @pytest.mark.parametrize("invalid_url", [None, "", " ", "wrong-url", "argilla.io", "http//argilla.io"]) - async def test_update_dataset_image_field_with_invalid_url( - self, async_client: AsyncClient, owner_auth_header: dict, invalid_url: Any - ): - image_field = await ImageFieldFactory.create() - - response = await async_client.patch( - self.url(image_field.id), - headers=owner_auth_header, - json={ - "settings": { - "type": FieldType.image, - "url": invalid_url, - }, - }, - ) - - assert response.status_code == 422 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index a259baa773..0703710d6b 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -2037,22 +2037,31 @@ async def test_create_dataset_records_with_wrong_value_field( await TextQuestionFactory.create(name="input_ok", dataset=dataset) await TextQuestionFactory.create(name="output_ok", dataset=dataset) - records_json = { - "items": [ - { - "fields": {"input": "Say Hello", "output": 33}, - }, - { - "fields": {"input": "Say Hello", "output": "Hi"}, - }, - { - "fields": {"input": "Say Pello", "output": "Hello World"}, - }, - ] - } - response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "input": "Say Hello", + "output": 33, + }, + }, + { + "fields": { + "input": "Say Hello", + "output": "Hi", + }, + }, + { + "fields": { + "input": "Say Pello", + "output": "Hello World", + }, + }, + ], + }, ) assert response.status_code == 422 @@ -2065,7 +2074,7 @@ async def test_create_dataset_records_with_wrong_value_field( "loc": ["body", "items", 0, "fields", "output"], "msg": "str type expected", "type": "type_error.str", - } + }, ] }, } @@ -2081,16 +2090,29 @@ async def test_create_dataset_records_with_extra_fields( await TextQuestionFactory.create(name="input_ok", dataset=dataset) await TextQuestionFactory.create(name="output_ok", dataset=dataset) - records_json = { - "items": [ - {"fields": {"input": "Say Hello", "output": "unexpected"}}, - {"fields": {"input": "Say Hello"}}, - {"fields": {"input": "Say Pello"}}, - ] - } - response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "input": "Say Hello", + "output": "unexpected", + }, + }, + { + "fields": { + "input": "Say Hello", + }, + }, + { + "fields": { + "input": "Say Pello", + }, + }, + ], + }, ) assert response.status_code == 422 @@ -2152,7 +2174,7 @@ async def test_create_dataset_records_with_wrong_optional_fields( "loc": ["body", "items", 0, "fields", "output"], "msg": "str type expected", "type": "type_error.str", - } + }, ] }, } diff --git a/argilla-server/tests/unit/database/models/test_field_model.py b/argilla-server/tests/unit/database/models/test_field_model.py new file mode 100644 index 0000000000..3969c5d2ad --- /dev/null +++ b/argilla-server/tests/unit/database/models/test_field_model.py @@ -0,0 +1,31 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from argilla_server.models import Field +from argilla_server.enums import FieldType + + +@pytest.mark.asyncio +class TestFieldModel: + def test_is_text_property(self): + assert Field(settings={"type": FieldType.text}).is_text == True + assert Field(settings={"type": FieldType.image}).is_text == False + assert Field(settings={}).is_text == False + + def test_is_image_property(self): + assert Field(settings={"type": FieldType.image}).is_image == True + assert Field(settings={"type": FieldType.text}).is_image == False + assert Field(settings={}).is_image == False From 047e8b02f9a1e9690273e46d47a44a410a333d81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 26 Jul 2024 14:15:07 +0200 Subject: [PATCH 04/88] chore: update CHANGELOG.md --- argilla-server/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 372e3ba18e..3fe51a37f6 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -18,7 +18,7 @@ These are the section headers that we use: ### Added -- Added new `image` type dataset field. ([#5279](https://github.com/argilla-io/argilla/pull/5279)) +- Added new `image` type dataset field supporting URLs and Data URLs. ([#5279](https://github.com/argilla-io/argilla/pull/5279)) ## [2.0.0](https://github.com/argilla-io/argilla/compare/v2.0.0rc1...v2.0.0) From 183813fa2334b38f7c988cc85bdafa8b54b6cf6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 26 Jul 2024 16:33:04 +0200 Subject: [PATCH 05/88] chore: improve import --- argilla-server/src/argilla_server/validators/records.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/argilla-server/src/argilla_server/validators/records.py b/argilla-server/src/argilla_server/validators/records.py index d20283b4cc..9f2d6961d9 100644 --- a/argilla-server/src/argilla_server/validators/records.py +++ b/argilla-server/src/argilla_server/validators/records.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union from uuid import UUID -from urllib.parse import ParseResultBytes, urlparse, ParseResult +from urllib.parse import urlparse, ParseResult, ParseResultBytes from sqlalchemy.ext.asyncio import AsyncSession From 1eecbeb0b88547330495ddebb0eed85f77f43c97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dami=C3=A1n=20Pumar?= Date: Wed, 31 Jul 2024 12:38:53 +0200 Subject: [PATCH 06/88] =?UTF-8?q?=F0=9F=93=9D=20Html=20sandbox?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../container/fields/text-field/Sandbox.vue | 45 ++++++++ .../fields/text-field/TextField.component.vue | 103 +++++++++++++++++- 2 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 argilla-frontend/components/features/annotation/container/fields/text-field/Sandbox.vue diff --git a/argilla-frontend/components/features/annotation/container/fields/text-field/Sandbox.vue b/argilla-frontend/components/features/annotation/container/fields/text-field/Sandbox.vue new file mode 100644 index 0000000000..429dae2660 --- /dev/null +++ b/argilla-frontend/components/features/annotation/container/fields/text-field/Sandbox.vue @@ -0,0 +1,45 @@ +