From 04e9e92520c898b9b95dc1e7db9fb5e8a7b0bb25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 13 Jun 2024 17:59:41 +0200 Subject: [PATCH] feat: add DatasetCreateValidator and move some validations from context --- .../src/argilla_server/contexts/datasets.py | 16 +++----- .../src/argilla_server/validators/datasets.py | 37 +++++++++++++++++++ 2 files changed, 43 insertions(+), 10 deletions(-) create mode 100644 argilla-server/src/argilla_server/validators/datasets.py diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index ea2b399424..a31cbcfe61 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -79,6 +79,7 @@ ) from argilla_server.models.suggestions import SuggestionCreateWithRecordId from argilla_server.search_engine import SearchEngine +from argilla_server.validators.datasets import DatasetCreateValidator from argilla_server.validators.responses import ( ResponseCreateValidator, ResponseUpdateValidator, @@ -120,16 +121,7 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> async def create_dataset(db: AsyncSession, dataset_attrs: dict): - if await Workspace.get(db, dataset_attrs["workspace_id"]) is None: - raise UnprocessableEntityError(f"Workspace with id `{dataset_attrs['workspace_id']}` not found") - - if await Dataset.get_by(db, name=dataset_attrs["name"], workspace_id=dataset_attrs["workspace_id"]): - raise NotUniqueError( - f"Dataset with name `{dataset_attrs['name']}` already exists for workspace with id `{dataset_attrs['workspace_id']}`" - ) - - return await Dataset.create( - db, + dataset = Dataset( name=dataset_attrs["name"], guidelines=dataset_attrs["guidelines"], allow_extra_metadata=dataset_attrs["allow_extra_metadata"], @@ -137,6 +129,10 @@ async def create_dataset(db: AsyncSession, dataset_attrs: dict): 
workspace_id=dataset_attrs["workspace_id"], ) + await DatasetCreateValidator.validate(db, dataset) + + return await dataset.save(db) + async def _count_required_fields_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int: return (await db.execute(select(func.count(Field.id)).filter_by(dataset_id=dataset_id, required=True))).scalar_one() diff --git a/argilla-server/src/argilla_server/validators/datasets.py b/argilla-server/src/argilla_server/validators/datasets.py new file mode 100644 index 0000000000..66b783daf7 --- /dev/null +++ b/argilla-server/src/argilla_server/validators/datasets.py @@ -0,0 +1,37 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from uuid import UUID
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
+from argilla_server.models import Dataset, Workspace
+
+
+class DatasetCreateValidator:  # checks run on a Dataset instance; each check raises on failure
+    @classmethod
+    async def validate(cls, db: AsyncSession, dataset: Dataset) -> None:  # workspace check, then name check
+        await cls._validate_workspace_is_present(db, dataset.workspace_id)
+        await cls._validate_name_is_not_duplicated(db, dataset.name, dataset.workspace_id)
+
+    @classmethod
+    async def _validate_workspace_is_present(cls, db: AsyncSession, workspace_id: UUID) -> None:
+        if await Workspace.get(db, workspace_id) is None:  # raises UnprocessableEntityError when lookup is None
+            raise UnprocessableEntityError(f"Workspace with id `{workspace_id}` not found")
+
+    @classmethod
+    async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, workspace_id: UUID) -> None:
+        if await Dataset.get_by(db, name=name, workspace_id=workspace_id) is not None:  # same name + workspace
+            raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`")