diff --git a/argilla/src/argilla/_api/_questions.py b/argilla/src/argilla/_api/_questions.py index 5b112bc76f..98d67eb6da 100644 --- a/argilla/src/argilla/_api/_questions.py +++ b/argilla/src/argilla/_api/_questions.py @@ -18,34 +18,16 @@ import httpx from argilla._api._base import ResourceAPI from argilla._exceptions import api_error_handler -from argilla._models import ( - TextQuestionModel, - LabelQuestionModel, - MultiLabelQuestionModel, - RankingQuestionModel, - RatingQuestionModel, - SpanQuestionModel, - QuestionBaseModel, - QuestionModel, -) +from argilla._models import QuestionModel __all__ = ["QuestionsAPI"] -class QuestionsAPI(ResourceAPI[QuestionBaseModel]): +class QuestionsAPI(ResourceAPI[QuestionModel]): """Manage datasets via the API""" http_client: httpx.Client - _TYPE_TO_MODEL_CLASS = { - "text": TextQuestionModel, - "label_selection": LabelQuestionModel, - "multi_label_selection": MultiLabelQuestionModel, - "ranking": RankingQuestionModel, - "rating": RatingQuestionModel, - "span": SpanQuestionModel, - } - ################ # CRUD methods # ################ @@ -53,15 +35,14 @@ class QuestionsAPI(ResourceAPI[QuestionBaseModel]): @api_error_handler def create( self, - dataset_id: UUID, question: QuestionModel, ) -> QuestionModel: - url = f"/api/v1/datasets/{dataset_id}/questions" + url = f"/api/v1/datasets/{question.dataset_id}/questions" response = self.http_client.post(url=url, json=question.model_dump()) response.raise_for_status() response_json = response.json() question_model = self._model_from_json(response_json=response_json) - self._log_message(message=f"Created question {question_model.name} in dataset {dataset_id}") + self._log_message(message=f"Created question {question_model.name} in dataset {question.dataset_id}") return question_model @api_error_handler @@ -69,25 +50,24 @@ def update( self, question: QuestionModel, ) -> QuestionModel: - # TODO: Implement update method for fields with server side ID - raise NotImplementedError + url = f"/api/v1/questions/{question.id}" + response = self.http_client.patch(url, json=question.model_dump()) + response.raise_for_status() + response_json = response.json() + updated_question = self._model_from_json(response_json) + self._log_message(message=f"Update question {updated_question.name} with id {question.id}") + return updated_question @api_error_handler def delete(self, question_id: UUID) -> None: - # TODO: Implement delete method for fields with server side ID - raise NotImplementedError + url = f"/api/v1/questions/{question_id}" + self.http_client.delete(url).raise_for_status() + self._log_message(message=f"Deleted question with id {question_id}") #################### # Utility methods # #################### - def create_many(self, dataset_id: UUID, questions: List[QuestionModel]) -> List[QuestionModel]: - response_models = [] - for question in questions: - response_model = self.create(dataset_id=dataset_id, question=question) - response_models.append(response_model) - return response_models - @api_error_handler def list(self, dataset_id: UUID) -> List[QuestionModel]: response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/questions") @@ -103,21 +83,7 @@ def list(self, dataset_id: UUID) -> List[QuestionModel]: def _model_from_json(self, response_json: Dict) -> QuestionModel: response_json["inserted_at"] = self._date_from_iso_format(date=response_json["inserted_at"]) response_json["updated_at"] = self._date_from_iso_format(date=response_json["updated_at"]) - return self._get_model_from_response(response_json=response_json) + return QuestionModel(**response_json) def _model_from_jsons(self, response_jsons: List[Dict]) -> List[QuestionModel]: return list(map(self._model_from_json, response_jsons)) - - def _get_model_from_response(self, response_json: Dict) -> QuestionModel: - """Get the model from the response""" - try: - question_type = response_json.get("settings", {}).get("type") - except Exception as e: - raise ValueError("Invalid field type: missing 'settings.type' in response") from e - - question_class = self._TYPE_TO_MODEL_CLASS.get(question_type) - if question_class is None: - self._log_message(message=f"Unknown question type: {question_type}") - question_class = QuestionBaseModel - - return question_class(**response_json, check_fields=False) diff --git a/argilla/src/argilla/_models/__init__.py b/argilla/src/argilla/_models/__init__.py index 553296d6dd..4f69b93024 100644 --- a/argilla/src/argilla/_models/__init__.py +++ b/argilla/src/argilla/_models/__init__.py @@ -39,18 +39,14 @@ FieldSettings, ) from argilla._models._settings._questions import ( - LabelQuestionModel, - LabelQuestionSettings, - MultiLabelQuestionModel, - QuestionBaseModel, QuestionModel, QuestionSettings, - RankingQuestionModel, - RatingQuestionModel, - SpanQuestionModel, SpanQuestionSettings, - TextQuestionModel, TextQuestionSettings, + LabelQuestionSettings, + RatingQuestionSettings, + MultiLabelQuestionSettings, + RankingQuestionSettings, ) from argilla._models._settings._metadata import ( MetadataFieldModel, @@ -61,5 +57,18 @@ FloatMetadataPropertySettings, IntegerMetadataPropertySettings, ) +from argilla._models._settings._questions import ( + QuestionModel, + QuestionSettings, + LabelQuestionSettings, + RatingQuestionSettings, + TextQuestionSettings, + MultiLabelQuestionSettings, + RankingQuestionSettings, + SpanQuestionSettings, +) from argilla._models._settings._vectors import VectorFieldModel + +from argilla._models._user import UserModel, Role +from argilla._models._workspace import WorkspaceModel from argilla._models._webhook import WebhookModel, EventType diff --git a/argilla/src/argilla/_models/_settings/_questions.py b/argilla/src/argilla/_models/_settings/_questions.py new file mode 100644 index 0000000000..558b351f23 --- /dev/null +++ b/argilla/src/argilla/_models/_settings/_questions.py @@ -0,0 +1,164 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Annotated, Union, Optional, ClassVar, List, Dict, Literal +from uuid import UUID + +from pydantic import ConfigDict, field_validator, Field, BaseModel, model_validator, field_serializer +from pydantic_core.core_schema import ValidationInfo + +from argilla._models import ResourceModel + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +class LabelQuestionSettings(BaseModel): + type: Literal["label_selection"] = "label_selection" + + _MIN_VISIBLE_OPTIONS: ClassVar[int] = 3 + + options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True) + visible_options: Optional[int] = Field(None, validate_default=True, ge=_MIN_VISIBLE_OPTIONS) + + @field_validator("options", mode="before") + @classmethod + def __labels_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]: + """Ensure that labels are unique""" + + unique_labels = list(set([option["value"] for option in options])) + if len(unique_labels) != len(options): + raise ValueError("All labels must be unique") + return options + + @model_validator(mode="after") + def __validate_visible_options(self) -> "Self": + if self.visible_options is None and self.options and len(self.options) >= self._MIN_VISIBLE_OPTIONS: + self.visible_options = len(self.options) + return self + + +class MultiLabelQuestionSettings(LabelQuestionSettings): + type: Literal["multi_label_selection"] = "multi_label_selection" + options_order: Literal["natural", "suggestion"] = Field("natural", description="The order of the labels in the UI.") + + +class RankingQuestionSettings(BaseModel): + type: Literal["ranking"] = "ranking" + + options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True) + + @field_validator("options", mode="before") + @classmethod + def __values_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]: + """Ensure that values are unique""" + + unique_values = list(set([option["value"] for option in options])) + if len(unique_values) != len(options): + raise ValueError("All values must be unique") + + return options + + +class RatingQuestionSettings(BaseModel): + type: Literal["rating"] = "rating" + + options: List[dict] = Field(..., validate_default=True) + + @field_validator("options", mode="before") + @classmethod + def __values_are_unique(cls, options: List[dict]) -> List[dict]: + """Ensure that values are unique""" + + unique_values = list(set([option["value"] for option in options])) + if len(unique_values) != len(options): + raise ValueError("All values must be unique") + + return options + + +class SpanQuestionSettings(BaseModel): + type: Literal["span"] = "span" + + _MIN_VISIBLE_OPTIONS: ClassVar[int] = 3 + + allow_overlapping: bool = False + field: Optional[str] = None + options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True) + visible_options: Optional[int] = Field(None, validate_default=True, ge=_MIN_VISIBLE_OPTIONS) + + @field_validator("options", mode="before") + @classmethod + def __values_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]: + """Ensure that values are unique""" + + unique_values = list(set([option["value"] for option in options])) + if len(unique_values) != len(options): + raise ValueError("All values must be unique") + + return options + + @model_validator(mode="after") + def __validate_visible_options(self) -> "Self": + if self.visible_options is None and self.options and len(self.options) >= self._MIN_VISIBLE_OPTIONS: + self.visible_options = len(self.options) + return self + + +class TextQuestionSettings(BaseModel): + type: Literal["text"] = "text" + + use_markdown: bool = False + + +QuestionSettings = Annotated[ + Union[ + LabelQuestionSettings, + MultiLabelQuestionSettings, + RankingQuestionSettings, + RatingQuestionSettings, + SpanQuestionSettings, + TextQuestionSettings, + ], + Field(..., discriminator="type"), +] + + +class QuestionModel(ResourceModel): + name: str + settings: QuestionSettings + + title: str = Field(None, validate_default=True) + description: Optional[str] = None + required: bool = True + + dataset_id: Optional[UUID] = None + + @field_validator("title", mode="before") + @classmethod + def _title_default(cls, title, info: ValidationInfo): + validated_title = title or info.data["name"] + return validated_title + + @property + def type(self) -> str: + return self.settings.type + + @field_serializer("id", "dataset_id", when_used="unless-none") + def serialize_id(self, value: UUID) -> str: + return str(value) + + model_config = ConfigDict(validate_assignment=True) diff --git a/argilla/src/argilla/_models/_settings/_questions/__init__.py b/argilla/src/argilla/_models/_settings/_questions/__init__.py deleted file mode 100644 index 403774c032..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# flake8: noqa -from typing import Union - -from argilla._models._settings._questions._label_selection import LabelQuestionModel, LabelQuestionSettings -from argilla._models._settings._questions._multi_label_selection import ( - MultiLabelQuestionModel, - MultiLabelQuestionSettings, -) -from argilla._models._settings._questions._rating import RatingQuestionModel, RatingQuestionSettings -from argilla._models._settings._questions._ranking import RankingQuestionModel, RankingQuestionSettings -from argilla._models._settings._questions._text import TextQuestionModel, TextQuestionSettings -from argilla._models._settings._questions._base import QuestionBaseModel, QuestionSettings -from argilla._models._settings._questions._span import SpanQuestionModel, SpanQuestionSettings - -QuestionModel = Union[ - LabelQuestionModel, - RatingQuestionModel, - TextQuestionModel, - MultiLabelQuestionModel, - RankingQuestionModel, - QuestionBaseModel, -] diff --git a/argilla/src/argilla/_models/_settings/_questions/_base.py b/argilla/src/argilla/_models/_settings/_questions/_base.py deleted file mode 100644 index e661689507..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/_base.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime -from typing import Optional -from uuid import UUID - -from pydantic import BaseModel, field_serializer, field_validator, Field -from pydantic_core.core_schema import ValidationInfo - - -class QuestionSettings(BaseModel, validate_assignment=True): - type: str - - -class QuestionBaseModel(BaseModel, validate_assignment=True): - id: Optional[UUID] = None - name: str - settings: QuestionSettings - - title: str = Field(None, validate_default=True) - description: Optional[str] = None - required: bool = True - inserted_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - - @field_validator("title", mode="before") - @classmethod - def __title_default(cls, title, info: ValidationInfo): - validated_title = title or info.data["name"] - return validated_title - - @field_serializer("inserted_at", "updated_at", when_used="unless-none") - def serialize_datetime(self, value: datetime) -> str: - return value.isoformat() - - @field_serializer("id", when_used="unless-none") - def serialize_id(self, value: UUID) -> str: - return str(value) diff --git a/argilla/src/argilla/_models/_settings/_questions/_label_selection.py b/argilla/src/argilla/_models/_settings/_questions/_label_selection.py deleted file mode 100644 index 358bf441e7..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/_label_selection.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Optional, ClassVar - -from pydantic import field_validator, Field, model_validator - -from argilla._models._settings._questions._base import QuestionSettings, QuestionBaseModel - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - - -class LabelQuestionSettings(QuestionSettings): - type: str = "label_selection" - - _MIN_VISIBLE_OPTIONS: ClassVar[int] = 3 - - options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True) - visible_options: Optional[int] = Field(None, validate_default=True, ge=_MIN_VISIBLE_OPTIONS) - - @field_validator("options", mode="before") - @classmethod - def __labels_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]: - """Ensure that labels are unique""" - - unique_labels = list(set([option["value"] for option in options])) - if len(unique_labels) != len(options): - raise ValueError("All labels must be unique") - return options - - @model_validator(mode="after") - def __validate_visible_options(self) -> "Self": - if self.visible_options is None and self.options and len(self.options) >= self._MIN_VISIBLE_OPTIONS: - self.visible_options = len(self.options) - return self - - -class LabelQuestionModel(QuestionBaseModel): - settings: LabelQuestionSettings diff --git a/argilla/src/argilla/_models/_settings/_questions/_multi_label_selection.py b/argilla/src/argilla/_models/_settings/_questions/_multi_label_selection.py deleted file mode 100644 index 8eeeb7f121..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/_multi_label_selection.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum - -from pydantic import Field - -from argilla._models._settings._questions._label_selection import LabelQuestionSettings, LabelQuestionModel - - -class OptionsOrder(str, Enum): - natural = "natural" - suggestion = "suggestion" - - -class MultiLabelQuestionSettings(LabelQuestionSettings): - type: str = "multi_label_selection" - options_order: OptionsOrder = Field(OptionsOrder.natural, description="The order of the labels in the UI.") - - -class MultiLabelQuestionModel(LabelQuestionModel): - settings: MultiLabelQuestionSettings diff --git a/argilla/src/argilla/_models/_settings/_questions/_ranking.py b/argilla/src/argilla/_models/_settings/_questions/_ranking.py deleted file mode 100644 index 6adb9aebac..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/_ranking.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Optional - -from pydantic import field_validator, Field - -from argilla._models._settings._questions._base import QuestionSettings, QuestionBaseModel - - -class RankingQuestionSettings(QuestionSettings): - type: str = "ranking" - - options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True) - - @field_validator("options", mode="before") - @classmethod - def __values_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]: - """Ensure that values are unique""" - - unique_values = list(set([option["value"] for option in options])) - if len(unique_values) != len(options): - raise ValueError("All values must be unique") - - return options - - -class RankingQuestionModel(QuestionBaseModel): - settings: RankingQuestionSettings diff --git a/argilla/src/argilla/_models/_settings/_questions/_rating.py b/argilla/src/argilla/_models/_settings/_questions/_rating.py deleted file mode 100644 index 9248bf3ca8..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/_rating.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -from pydantic import field_validator, Field - -from argilla._models._settings._questions._base import QuestionSettings, QuestionBaseModel - - -class RatingQuestionSettings(QuestionSettings): - type: str = "rating" - - options: List[dict] = Field(..., validate_default=True) - - @field_validator("options", mode="before") - @classmethod - def __values_are_unique(cls, options: List[dict]) -> List[dict]: - """Ensure that values are unique""" - - unique_values = list(set([option["value"] for option in options])) - if len(unique_values) != len(options): - raise ValueError("All values must be unique") - - return options - - -class RatingQuestionModel(QuestionBaseModel): - settings: RatingQuestionSettings diff --git a/argilla/src/argilla/_models/_settings/_questions/_span.py b/argilla/src/argilla/_models/_settings/_questions/_span.py deleted file mode 100644 index a24b9e1059..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/_span.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Optional, ClassVar - -from pydantic import field_validator, Field, model_validator - -from argilla._models._settings._questions._base import QuestionSettings, QuestionBaseModel - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - - -class SpanQuestionSettings(QuestionSettings): - type: str = "span" - - _MIN_VISIBLE_OPTIONS: ClassVar[int] = 3 - - allow_overlapping: bool = False - field: Optional[str] = None - options: List[Dict[str, Optional[str]]] = Field(default_factory=list, validate_default=True) - visible_options: Optional[int] = Field(None, validate_default=True, ge=_MIN_VISIBLE_OPTIONS) - - @field_validator("options", mode="before") - @classmethod - def __values_are_unique(cls, options: List[Dict[str, Optional[str]]]) -> List[Dict[str, Optional[str]]]: - """Ensure that values are unique""" - - unique_values = list(set([option["value"] for option in options])) - if len(unique_values) != len(options): - raise ValueError("All values must be unique") - - return options - - @model_validator(mode="after") - def __validate_visible_options(self) -> "Self": - if self.visible_options is None and self.options and len(self.options) >= self._MIN_VISIBLE_OPTIONS: - self.visible_options = len(self.options) - return self - - -class SpanQuestionModel(QuestionBaseModel): - settings: SpanQuestionSettings diff --git a/argilla/src/argilla/_models/_settings/_questions/_text.py b/argilla/src/argilla/_models/_settings/_questions/_text.py deleted file mode 100644 index 86d4a43f12..0000000000 --- a/argilla/src/argilla/_models/_settings/_questions/_text.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from argilla._models._settings._questions._base import QuestionSettings, QuestionBaseModel - - -class TextQuestionSettings(QuestionSettings): - type: str = "text" - - use_markdown: bool = False - - -class TextQuestionModel(QuestionBaseModel): - settings: TextQuestionSettings diff --git a/argilla/src/argilla/records/_mapping/_mapper.py b/argilla/src/argilla/records/_mapping/_mapper.py index a4c4a398a8..be65717ab0 100644 --- a/argilla/src/argilla/records/_mapping/_mapper.py +++ b/argilla/src/argilla/records/_mapping/_mapper.py @@ -22,7 +22,7 @@ from argilla.responses import Response from argilla.settings import FieldBase, VectorField from argilla.settings._metadata import MetadataPropertyBase -from argilla.settings._question import QuestionPropertyBase +from argilla.settings._question import QuestionBase from argilla.suggestions import Suggestion from argilla.records._mapping._routes import ( AttributeRoute, @@ -177,12 +177,12 @@ def _select_attribute_type(self, attribute_route: AttributeRoute) -> AttributeRo If the attribute type is not provided, it will be inferred based on the schema item. """ schema_item = self._schema.get(attribute_route.name) - if isinstance(schema_item, QuestionPropertyBase) and ( + if isinstance(schema_item, QuestionBase) and ( attribute_route.type is None or attribute_route.type == AttributeType.SUGGESTION ): # Suggestions are the default destination for questions. attribute_route.type = AttributeType.SUGGESTION - elif isinstance(schema_item, QuestionPropertyBase) and attribute_route.type == AttributeType.RESPONSE: + elif isinstance(schema_item, QuestionBase) and attribute_route.type == AttributeType.RESPONSE: attribute_route.type = AttributeType.RESPONSE elif isinstance(schema_item, FieldBase): attribute_route.type = AttributeType.FIELD diff --git a/argilla/src/argilla/settings/_common.py b/argilla/src/argilla/settings/_common.py index b5760d1f78..be3f943c0e 100644 --- a/argilla/src/argilla/settings/_common.py +++ b/argilla/src/argilla/settings/_common.py @@ -14,7 +14,7 @@ from typing import Any, Optional, Union -from argilla._models import FieldModel, QuestionBaseModel +from argilla._models import FieldModel, QuestionModel from argilla._resource import Resource __all__ = ["SettingsPropertyBase"] @@ -23,7 +23,7 @@ class SettingsPropertyBase(Resource): """Base class for dataset fields or questions in Settings class""" - _model: Union[FieldModel, QuestionBaseModel] + _model: Union[FieldModel, QuestionModel] def __repr__(self) -> str: return ( diff --git a/argilla/src/argilla/settings/_question.py b/argilla/src/argilla/settings/_question.py index 262dddf1c8..63fb19f208 100644 --- a/argilla/src/argilla/settings/_question.py +++ b/argilla/src/argilla/settings/_question.py @@ -12,26 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union, TYPE_CHECKING from argilla import Argilla +from argilla._api import QuestionsAPI from argilla._models._settings._questions import ( - LabelQuestionModel, + QuestionModel, + QuestionSettings, LabelQuestionSettings, - MultiLabelQuestionModel, MultiLabelQuestionSettings, - QuestionModel, - RankingQuestionModel, - RankingQuestionSettings, - RatingQuestionModel, + TextQuestionSettings, RatingQuestionSettings, - SpanQuestionModel, + RankingQuestionSettings, SpanQuestionSettings, - TextQuestionModel, - TextQuestionSettings, ) from argilla.settings._common import SettingsPropertyBase +if TYPE_CHECKING: + from argilla.datasets import Dataset + try: from typing import Self except ImportError: @@ -48,7 +47,62 @@ ] -class QuestionPropertyBase(SettingsPropertyBase): +class QuestionBase(SettingsPropertyBase): + _model: QuestionModel + _api: QuestionsAPI + _dataset: Optional["Dataset"] + + def __init__( + self, + name: str, + settings: QuestionSettings, + title: Optional[str] = None, + required: Optional[bool] = True, + description: Optional[str] = None, + _client: Optional[Argilla] = None, + ): + client = _client or Argilla._get_default() + + super().__init__(api=client.api.questions, client=client) + + self._dataset = None + self._model = QuestionModel( + name=name, + settings=settings, + title=title, + required=required, + description=description, + ) + + @classmethod + def from_model(cls, model: QuestionModel) -> "Self": + instance = cls(name=model.name) # noqa + instance._model = model + + return instance + + @classmethod + def from_dict(cls, data: dict) -> "Self": + model = QuestionModel(**data) + return cls.from_model(model) + + @property + def dataset(self) -> "Dataset": + return self._dataset + + @dataset.setter + def dataset(self, value: "Dataset") -> None: + self._dataset = value + self._model.dataset_id = self._dataset.id + self._with_client(self._dataset._client) + + def _with_client(self, client: "Argilla") -> "Self": + # TODO: Review and simplify. Maybe only one of them is required + self._client = client + self._api = self._client.api.questions + + return self + @staticmethod def _render_values_as_options(values: Union[List[str], List[int], Dict[str, str]]) -> List[Dict[str, str]]: """Render values as options for the question so that the model conforms to the API schema""" @@ -79,16 +133,8 @@ def _render_options_as_labels(cls, options: List[Dict[str, str]]) -> List[str]: """Render values as labels for the question so that they can be returned as a list of strings""" return list(cls._render_options_as_values(options=options).keys()) - def _with_client(self, client: "Argilla") -> "Self": - self._client = client - self._api = client.api.questions - - return self - - -class LabelQuestion(QuestionPropertyBase): - _model: LabelQuestionModel +class LabelQuestion(QuestionBase): def __init__( self, name: str, @@ -97,6 +143,7 @@ def __init__( description: Optional[str] = None, required: bool = True, visible_labels: Optional[int] = None, + client: Optional[Argilla] = None, ) -> None: """ Define a new label question for `Settings` of a `Dataset`. A label \ question is a question where the user can select one label from \ @@ -112,27 +159,19 @@ def __init__( visible_labels (Optional[int]): The number of visible labels for the question to be shown in the UI. \ Setting it to None show all options. """ - self._model = LabelQuestionModel( + + super().__init__( name=name, title=title, - description=description, required=required, + description=description, settings=LabelQuestionSettings( - options=self._render_values_as_options(labels), visible_options=visible_labels + options=self._render_values_as_options(labels), + visible_options=visible_labels, ), + _client=client, ) - @classmethod - def from_model(cls, model: LabelQuestionModel) -> "LabelQuestion": - instance = cls(name=model.name, labels=cls._render_options_as_values(model.settings.options)) - instance._model = model - return instance - - @classmethod - def from_dict(cls, data: dict) -> "LabelQuestion": - model = LabelQuestionModel(**data) - return cls.from_model(model=model) - ############################## # Public properties ############################## @@ -153,14 +192,15 @@ def visible_labels(self) -> Optional[int]: def visible_labels(self, visible_labels: Optional[int]) -> None: self._model.settings.visible_options = visible_labels - ############################## - # Private methods - ############################## + @classmethod + def from_model(cls, model: QuestionModel) -> "Self": + instance = cls(name=model.name, labels=cls._render_options_as_labels(model.settings.options)) # noqa + instance._model = model + return instance -class MultiLabelQuestion(LabelQuestion): - _model: MultiLabelQuestionModel +class MultiLabelQuestion(LabelQuestion): def __init__( self, name: str, @@ -170,6 +210,7 @@ def __init__( title: Optional[str] = None, description: Optional[str] = None, required: bool = True, + client: Optional[Argilla] = None, ) -> None: """Create a new multi-label question for `Settings` of a `Dataset`. A \ multi-label question is a question where the user can select multiple \ @@ -188,38 +229,29 @@ def __init__( description (Optional[str]): The description of the question to be shown in the UI. required (bool): If the question is required for a record to be valid. At least one question must be required. """ - self._model = MultiLabelQuestionModel( + QuestionBase.__init__( + self, name=name, title=title, - description=description, required=required, + description=description, settings=MultiLabelQuestionSettings( options=self._render_values_as_options(labels), visible_options=visible_labels, options_order=labels_order, ), + _client=client, ) @classmethod - def from_model(cls, model: MultiLabelQuestionModel) -> "MultiLabelQuestion": - instance = cls( - name=model.name, - labels=cls._render_options_as_values(model.settings.options), - labels_order=model.settings.options_order, - ) + def from_model(cls, model: QuestionModel) -> "Self": + instance = cls(name=model.name, labels=cls._render_options_as_labels(model.settings.options)) # noqa instance._model = model return instance - @classmethod - def from_dict(cls, data: dict) -> "MultiLabelQuestion": - model = MultiLabelQuestionModel(**data) - return cls.from_model(model=model) - - -class TextQuestion(QuestionPropertyBase): - _model: TextQuestionModel +class TextQuestion(QuestionBase): def __init__( self, name: str, @@ -227,6 +259,7 @@ def __init__( description: Optional[str] = None, required: bool = True, use_markdown: bool = False, + client: Optional[Argilla] = None, ) -> None: """Create a new text question for `Settings` of a `Dataset`. A text question \ is a question where the user can input text. @@ -239,26 +272,15 @@ def __init__( use_markdown (Optional[bool]): Whether to render the markdown in the UI. When True, you will be able \ to use all the Markdown features for text formatting, including LaTex formulas and embedding multimedia content and PDFs. """ - self._model = TextQuestionModel( + super().__init__( name=name, title=title, - description=description, required=required, + description=description, settings=TextQuestionSettings(use_markdown=use_markdown), + _client=client, ) - @classmethod - def from_model(cls, model: TextQuestionModel) -> "TextQuestion": - instance = cls(name=model.name) - instance._model = model - - return instance - - @classmethod - def from_dict(cls, data: dict) -> "TextQuestion": - model = TextQuestionModel(**data) - return cls.from_model(model=model) - @property def use_markdown(self) -> bool: return self._model.settings.use_markdown @@ -268,9 +290,7 @@ def use_markdown(self, use_markdown: bool) -> None: self._model.settings.use_markdown = use_markdown -class RatingQuestion(QuestionPropertyBase): - _model: RatingQuestionModel - +class RatingQuestion(QuestionBase): def __init__( self, name: str, @@ -278,6 +298,7 @@ def __init__( title: Optional[str] = None, description: Optional[str] = None, required: bool = True, + client: Optional[Argilla] = None, ) -> None: """Create a new rating question for `Settings` of a `Dataset`. A rating question \ is a question where the user can select a value from a sequential list of options. @@ -289,39 +310,33 @@ def __init__( description (Optional[str]): The description of the question to be shown in the UI. required (bool): If the question is required for a record to be valid. At least one question must be required. """ - self._model = RatingQuestionModel( + + super().__init__( name=name, title=title, - description=description, required=required, - values=values, + description=description, settings=RatingQuestionSettings(options=self._render_values_as_options(values)), + _client=client, ) - @classmethod - def from_model(cls, model: RatingQuestionModel) -> "RatingQuestion": - instance = cls(name=model.name, values=cls._render_options_as_values(model.settings.options)) - instance._model = model - - return instance - - @classmethod - def from_dict(cls, data: dict) -> "RatingQuestion": - model = RatingQuestionModel(**data) - return cls.from_model(model=model) - @property def values(self) -> List[int]: - return self._render_options_as_labels(self._model.settings.options) + return self._render_options_as_labels(self._model.settings.options) # noqa @values.setter def values(self, values: List[int]) -> None: self._model.values = self._render_values_as_options(values) + @classmethod + def from_model(cls, model: QuestionModel) -> "Self": + instance = cls(name=model.name, values=cls._render_options_as_labels(model.settings.options)) # noqa + instance._model = model + + return instance -class RankingQuestion(QuestionPropertyBase): - _model: RankingQuestionModel +class RankingQuestion(QuestionBase): def __init__( self, name: str, @@ -329,6 +344,7 @@ def __init__( title: Optional[str] = None, description: Optional[str] = None, required: bool = True, + client: Optional[Argilla] = None, ) -> None: """Create a new ranking question for `Settings` of a `Dataset`. A ranking question \ is a question where the user can rank a list of options. @@ -341,26 +357,15 @@ def __init__( description (Optional[str]): The description of the question to be shown in the UI. required (bool): If the question is required for a record to be valid. At least one question must be required. """ - self._model = RankingQuestionModel( + super().__init__( name=name, title=title, - description=description, required=required, + description=description, settings=RankingQuestionSettings(options=self._render_values_as_options(values)), + _client=client, ) - @classmethod - def from_model(cls, model: RankingQuestionModel) -> "RankingQuestion": - instance = cls(name=model.name, values=cls._render_options_as_values(model.settings.options)) - instance._model = model - - return instance - - @classmethod - def from_dict(cls, data: dict) -> "RankingQuestion": - model = RankingQuestionModel(**data) - return cls.from_model(model=model) - @property def values(self) -> List[str]: return self._render_options_as_labels(self._model.settings.options) @@ -369,10 +374,15 @@ def values(self) -> List[str]: def values(self, values: List[int]) -> None: self._model.settings.options = self._render_values_as_options(values) + @classmethod + def from_model(cls, model: QuestionModel) -> "Self": + instance = cls(name=model.name, values=cls._render_options_as_labels(model.settings.options)) # noqa + instance._model = model + + return instance -class SpanQuestion(QuestionPropertyBase): - _model: SpanQuestionModel +class SpanQuestion(QuestionBase): def __init__( self, name: str, @@ -383,6 +393,7 @@ def __init__( title: Optional[str] = None, description: Optional[str] = None, required: bool = True, + client: Optional[Argilla] = None, ): """ Create a new span question for `Settings` of a `Dataset`. A span question \ is a question where the user can select a section of text within a text field \ @@ -400,23 +411,20 @@ def __init__( description (Optional[str]): The description of the question to be shown in the UI. required (bool): If the question is required for a record to be valid. At least one question must be required. """ - self._model = SpanQuestionModel( + super().__init__( name=name, title=title, - description=description, required=required, + description=description, settings=SpanQuestionSettings( field=field, allow_overlapping=allow_overlapping, visible_options=visible_labels, options=self._render_values_as_options(labels), ), + _client=client, ) - @property - def name(self): - return self._model.name - @property def field(self): return self._model.settings.field @@ -450,21 +458,16 @@ def labels(self, labels: List[str]) -> None: self._model.settings.options = self._render_values_as_options(labels) @classmethod - def from_model(cls, model: SpanQuestionModel) -> "SpanQuestion": + def from_model(cls, model: QuestionModel) -> "Self": instance = cls( name=model.name, field=model.settings.field, - labels=cls._render_options_as_values(model.settings.options), - ) + labels=cls._render_options_as_labels(model.settings.options), + ) # noqa instance._model = model return instance - @classmethod - def from_dict(cls, data: dict) -> "SpanQuestion": - model = SpanQuestionModel(**data) - return cls.from_model(model=model) - QuestionType = Union[ LabelQuestion, @@ -475,25 +478,25 @@ def from_dict(cls, data: dict) -> "SpanQuestion": SpanQuestion, ] -_TYPE_TO_CLASS = { - "label_selection": LabelQuestion, - "multi_label_selection": MultiLabelQuestion, - "ranking": RankingQuestion, - "text": TextQuestion, - "rating": RatingQuestion, - "span": SpanQuestion, -} - def question_from_model(model: QuestionModel) -> QuestionType: - try: - return _TYPE_TO_CLASS[model.settings.type].from_model(model) - except KeyError: - raise ValueError(f"Unsupported question model type: {model.settings.type}") - - -def question_from_dict(data: dict) -> QuestionType: - try: - return _TYPE_TO_CLASS[data["settings"]["type"]].from_dict(data) - except KeyError: - raise ValueError(f"Unsupported question model type: {data['settings']['type']}") + question_type = model.type + + if question_type == "label_selection": + return LabelQuestion.from_model(model) + elif question_type == "multi_label_selection": + return MultiLabelQuestion.from_model(model) + elif question_type == "ranking": + return RankingQuestion.from_model(model) + elif question_type == "text": + return TextQuestion.from_model(model) + elif question_type == "rating": + return RatingQuestion.from_model(model) + elif question_type == "span": + return SpanQuestion.from_model(model) + else: + raise ValueError(f"Unsupported question model type: {question_type}") + + +def _question_from_dict(data: dict) -> QuestionType: + return question_from_model(QuestionModel(**data)) diff --git a/argilla/src/argilla/settings/_resource.py b/argilla/src/argilla/settings/_resource.py index 97ced197c9..6971db722f 100644 --- a/argilla/src/argilla/settings/_resource.py +++ b/argilla/src/argilla/settings/_resource.py @@ -26,7 +26,7 @@ from argilla.settings._field import Field, _field_from_dict, _field_from_model, FieldBase from argilla.settings._io import build_settings_from_repo_id from argilla.settings._metadata import MetadataType, MetadataField, MetadataPropertyBase -from argilla.settings._question import QuestionType, question_from_model, question_from_dict, QuestionPropertyBase +from argilla.settings._question import QuestionType, question_from_model, _question_from_dict, QuestionBase from argilla.settings._task_distribution import TaskDistribution from argilla.settings._templates import DefaultSettingsMixin from argilla.settings._vector import VectorField @@ -78,7 +78,7 @@ def __init__( self.__guidelines = self.__process_guidelines(guidelines) self.__allow_extra_metadata = allow_extra_metadata - self.__questions = QuestionsProperties(self, questions) + self.__questions = SettingsProperties(self, questions) self.__fields = SettingsProperties(self, fields) self.__vectors = SettingsProperties(self, vectors) self.__metadata = SettingsProperties(self, metadata) @@ -101,7 +101,7 @@ def questions(self) -> "SettingsProperties": @questions.setter def questions(self, questions: List[QuestionType]): - self.__questions = QuestionsProperties(self, questions) + self.__questions = SettingsProperties(self, questions) @property def vectors(self) -> "SettingsProperties": @@ -220,6 +220,7 @@ def update(self) -> "Resource": self._update_dataset_related_attributes() self.__fields._update() + self.__questions._update() self.__vectors._update() self.__metadata._update() self.__questions._update() @@ -314,7 +315,7 @@ def add( if isinstance(property, FieldBase): self.fields.add(property) - elif isinstance(property, QuestionPropertyBase): + elif isinstance(property, QuestionBase): self.questions.add(property) elif isinstance(property, VectorField): self.vectors.add(property) @@ -349,7 +350,7 @@ def _from_dict(cls, settings_dict: dict) -> "Settings": allow_extra_metadata = settings_dict.get("allow_extra_metadata") mapping = settings_dict.get("mapping") - questions = [question_from_dict(question) for question in settings_dict.get("questions", [])] + questions = [_question_from_dict(question) for question in settings_dict.get("questions", [])] fields = [_field_from_dict(field) for field in fields] vectors = [VectorField.from_dict(vector) for vector in vectors] metadata = [MetadataField.from_dict(metadata) for metadata in metadata] @@ -566,34 +567,3 @@ def __repr__(self) -> str: """Return a string representation of the object.""" return f"{repr([prop for prop in self])}" - - -class QuestionsProperties(SettingsProperties[QuestionType]): - """ - This class is used to align questions with the rest of the settings. - - Since questions are not aligned with the Resource class definition, we use this - class to work with questions as we do with fields, vectors, or metadata (specially when creating questions). - - Once issue https://github.com/argilla-io/argilla/issues/4931 is tackled, this class should be removed. - """ - - def _create(self): - for question in self: - try: - self._create_question(question) - except ArgillaAPIError as e: - raise SettingsError(f"Failed to create question {question.name}") from e - - def _update(self): - pass - - def _delete(self): - pass - - def _create_question(self, question: QuestionType) -> None: - question_model = self._settings._client.api.questions.create( - dataset_id=self._settings.dataset.id, - question=question.api_model(), - ) - question._model = question_model diff --git a/argilla/tests/integration/conftest.py b/argilla/tests/integration/conftest.py index 655c98f76d..7e0850ceed 100644 --- a/argilla/tests/integration/conftest.py +++ b/argilla/tests/integration/conftest.py @@ -32,6 +32,10 @@ def client() -> rg.Argilla: def _cleanup(client: rg.Argilla): + for dataset in client.datasets: + if dataset.name.startswith("test_"): + dataset.delete() + for workspace in client.workspaces: if workspace.name.startswith("test_"): for dataset in workspace.datasets: diff --git a/argilla/tests/integration/test_add_records.py b/argilla/tests/integration/test_add_records.py index a6b87ca96c..11b9652125 100644 --- a/argilla/tests/integration/test_add_records.py +++ b/argilla/tests/integration/test_add_records.py @@ -72,7 +72,7 @@ def test_add_records(client): assert dataset_records[2].fields["text"] == mock_data[2]["text"] -def test_add_dict_records(client: Argilla): +def test_add_dict_records(client: Argilla, dataset_name: str): ws_name = "new_ws" ws = client.workspaces(ws_name) or Workspace(name=ws_name).create() @@ -80,7 +80,7 @@ def test_add_dict_records(client: Argilla): if ds is not None: ds.delete() - ds = rg.Dataset(name="new_ds", workspace=ws) + ds = rg.Dataset(name=dataset_name, workspace=ws) ds.settings = rg.Settings( fields=[rg.TextField(name="text")], questions=[rg.TextQuestion(name="label")], diff --git a/argilla/tests/integration/test_export_dataset.py b/argilla/tests/integration/test_export_dataset.py index 0a226bd1f5..dc8a719daa 100644 --- a/argilla/tests/integration/test_export_dataset.py +++ b/argilla/tests/integration/test_export_dataset.py @@ -31,8 +31,7 @@ @pytest.fixture -def dataset(client) -> rg.Dataset: - mock_dataset_name = "".join(random.choices(ascii_lowercase, k=16)) +def dataset(client, dataset_name: str) -> rg.Dataset: settings = rg.Settings( fields=[ rg.TextField(name="text"), @@ -44,7 +43,7 @@ def dataset(client) -> rg.Dataset: ], ) dataset = rg.Dataset( - name=mock_dataset_name, + name=dataset_name, settings=settings, client=client, ) diff --git a/argilla/tests/integration/test_export_records.py b/argilla/tests/integration/test_export_records.py index 0314cd8741..7b414d9f95 100644 --- a/argilla/tests/integration/test_export_records.py +++ b/argilla/tests/integration/test_export_records.py @@ -28,8 +28,7 @@ @pytest.fixture -def dataset(client) -> rg.Dataset: - mock_dataset_name = "".join(random.choices(ascii_lowercase, k=16)) +def dataset(client, dataset_name: str) -> rg.Dataset: settings = rg.Settings( fields=[ rg.TextField(name="text"), @@ -41,7 +40,7 @@ def dataset(client) -> rg.Dataset: ], ) dataset = rg.Dataset( - name=mock_dataset_name, + name=dataset_name, settings=settings, client=client, ) diff --git a/argilla/tests/integration/test_import_features.py b/argilla/tests/integration/test_import_features.py index 6c1f530661..1f85213c7e 100644 --- a/argilla/tests/integration/test_import_features.py +++ b/argilla/tests/integration/test_import_features.py @@ -30,8 +30,7 @@ @pytest.fixture -def dataset(client) -> rg.Dataset: - mock_dataset_name = "".join(random.choices(ascii_lowercase, k=16)) +def dataset(client, dataset_name: str) -> rg.Dataset: settings = rg.Settings( fields=[ rg.TextField(name="text"), @@ -42,7 +41,7 @@ def dataset(client) -> rg.Dataset: ], ) dataset = rg.Dataset( - name=mock_dataset_name, + name=dataset_name, settings=settings, client=client, ) diff --git a/argilla/tests/integration/test_metadata.py b/argilla/tests/integration/test_manage_metadata.py similarity index 88% rename from argilla/tests/integration/test_metadata.py rename to argilla/tests/integration/test_manage_metadata.py index 2aa9d7c2f2..1acaa65035 100644 --- a/argilla/tests/integration/test_metadata.py +++ b/argilla/tests/integration/test_manage_metadata.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random -from string import ascii_lowercase - import pytest import argilla as rg @@ -22,8 +19,7 @@ @pytest.fixture -def dataset_with_metadata(client: Argilla, workspace: Workspace) -> Dataset: - name = "".join(random.choices(ascii_lowercase, k=16)) +def dataset_with_metadata(client: Argilla, workspace: Workspace, dataset_name: str) -> Dataset: settings = Settings( fields=[TextField(name="text")], questions=[LabelQuestion(name="label", labels=["positive", "negative"])], @@ -32,7 +28,7 @@ def dataset_with_metadata(client: Argilla, workspace: Workspace) -> Dataset: ], ) dataset = Dataset( - name=name, + name=dataset_name, workspace=workspace.name, settings=settings, client=client, @@ -42,8 +38,7 @@ def dataset_with_metadata(client: Argilla, workspace: Workspace) -> Dataset: return dataset -def test_create_dataset_with_metadata(client: Argilla, workspace: Workspace) -> Dataset: - name = "".join(random.choices(ascii_lowercase, k=16)) +def test_create_dataset_with_metadata(client: Argilla, workspace: Workspace, dataset_name: str) -> None: settings = Settings( fields=[TextField(name="text")], questions=[LabelQuestion(name="label", labels=["positive", "negative"])], @@ -52,7 +47,7 @@ def test_create_dataset_with_metadata(client: Argilla, workspace: Workspace) -> ], ) dataset = Dataset( - name=name, + name=dataset_name, workspace=workspace.name, settings=settings, client=client, @@ -72,8 +67,9 @@ def test_create_dataset_with_metadata(client: Argilla, workspace: Workspace) -> (None, None, rg.IntegerMetadataProperty), ], ) -def test_create_dataset_with_numerical_metadata(client: Argilla, workspace: Workspace, min, max, type) -> Dataset: - name = "".join(random.choices(ascii_lowercase, k=16)) +def test_create_dataset_with_numerical_metadata( + client: Argilla, workspace: Workspace, dataset_name: str, min, max, type +) -> None: settings = Settings( fields=[TextField(name="text")], questions=[LabelQuestion(name="label", labels=["positive", "negative"])], @@ -82,7 +78,7 @@ def test_create_dataset_with_numerical_metadata(client: Argilla, workspace: Work ], ) dataset = Dataset( - name=name, + name=dataset_name, workspace=workspace.name, settings=settings, client=client, diff --git a/argilla/tests/integration/test_publish_datasets.py b/argilla/tests/integration/test_publish_datasets.py index 057a08d646..9ed8245509 100644 --- a/argilla/tests/integration/test_publish_datasets.py +++ b/argilla/tests/integration/test_publish_datasets.py @@ -31,19 +31,18 @@ ) -def test_publish_dataset(client: "Argilla"): +def test_publish_dataset(client: "Argilla", dataset_name: str): ws_name = "new_ws" - ds_name = "new_ds" new_ws = client.workspaces(ws_name) or Workspace(name=ws_name).create() assert client.api.workspaces.exists(new_ws.id), "The workspace was not created" - ds = client.datasets(ds_name, workspace=new_ws) + ds = client.datasets(dataset_name, workspace=new_ws) if ds: ds.delete() assert not client.api.datasets.exists(ds.id), "The dataset was not deleted" - ds = Dataset(name=ds_name, workspace=new_ws) + ds = Dataset(name=dataset_name, workspace=new_ws) ds.settings = Settings( guidelines="This is a test dataset", diff --git a/argilla/tests/integration/test_update_dataset_settings.py b/argilla/tests/integration/test_update_dataset_settings.py index 5ec1883fba..16865f5fe8 100644 --- a/argilla/tests/integration/test_update_dataset_settings.py +++ b/argilla/tests/integration/test_update_dataset_settings.py @@ -64,6 +64,15 @@ def test_update_settings(self, client: Argilla, dataset: Dataset): dataset = client.datasets(dataset.name) assert dataset.settings.vectors["vector"].title == "A new title for vector" + def test_update_question_title(self, client: Argilla, dataset: Dataset): + question = dataset.settings.questions["label"] + question.title = "A new title for label question" + dataset.settings.update() + + dataset = client.datasets(dataset.name) + question = dataset.settings.questions["label"] + assert question.title == "A new title for label question" + def test_update_distribution_settings(self, client: Argilla, dataset: Dataset): dataset.settings.distribution.min_submitted = 100 dataset.update() diff --git a/argilla/tests/integration/test_update_records.py b/argilla/tests/integration/test_update_records.py index 1dc60c85fa..3690a3cd54 100644 --- a/argilla/tests/integration/test_update_records.py +++ b/argilla/tests/integration/test_update_records.py @@ -24,9 +24,8 @@ @pytest.fixture -def dataset(client: rg.Argilla) -> rg.Dataset: +def dataset(client: rg.Argilla, dataset_name: str) -> rg.Dataset: workspace = client.workspaces[0] - mock_dataset_name = "".join(random.choices(ascii_lowercase, k=16)) settings = rg.Settings( allow_extra_metadata=True, fields=[ @@ -37,7 +36,7 @@ def dataset(client: rg.Argilla) -> rg.Dataset: ], ) dataset = rg.Dataset( - name=mock_dataset_name, + name=dataset_name, workspace=workspace.name, settings=settings, client=client, diff --git a/argilla/tests/unit/test_resources/test_questions.py b/argilla/tests/unit/test_resources/test_questions.py index f4bd1ecec7..ab5cef3a25 100644 --- a/argilla/tests/unit/test_resources/test_questions.py +++ b/argilla/tests/unit/test_resources/test_questions.py @@ -19,76 +19,10 @@ from pytest_httpx import HTTPXMock import argilla as rg -from argilla._models import TextQuestionModel, LabelQuestionModel -from argilla._models._settings._questions import SpanQuestionModel +from argilla._models import QuestionModel class TestQuestionsAPI: - def test_create_many_questions(self, httpx_mock: HTTPXMock): - # TODO: Add a test for the delete method in client - mock_dataset_id = uuid.uuid4() - mock_return_value = { - "id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", - "name": "string", - "title": "string", - "required": True, - "settings": {"type": "text", "use_markdown": False}, - "dataset_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", - "inserted_at": datetime.now().isoformat(), - "updated_at": datetime.now().isoformat(), - } - mock_question = { - "name": "5044cv0wu5", - "title": "string", - "description": "string", - "required": True, - "settings": {"type": "text", "use_markdown": False}, - } - mock_question = TextQuestionModel(**mock_question) - httpx_mock.add_response( - json=mock_return_value, - url=f"http://test_url/api/v1/datasets/{mock_dataset_id}/questions", - method="POST", - status_code=200, - ) - with httpx.Client() as client: - client = rg.Argilla(api_url="http://test_url") - client.api.questions.create_many(dataset_id=mock_dataset_id, questions=[mock_question]) - - def test_create_many_label_questions(self, httpx_mock: HTTPXMock): - # TODO: Add a test for the delete method in client - mock_dataset_id = uuid.uuid4() - mock_return_value = { - "id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", - "name": "string", - "title": "string", - "required": True, - "settings": {"type": "labels", "options": [{"text": "positive", "value": "positive"}]}, - "dataset_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6", - "inserted_at": datetime.now().isoformat(), - "updated_at": datetime.now().isoformat(), - } - mock_question = { - "name": "5044cv0wu5", - "title": "string", - "description": "string", - "required": True, - "settings": { - "type": "label", - "options": [{"text": "negative", "value": "negative"}, {"text": "positive", "value": "positive"}], - }, - } - mock_question = LabelQuestionModel(**mock_question) - httpx_mock.add_response( - json=mock_return_value, - url=f"http://test_url/api/v1/datasets/{mock_dataset_id}/questions", - method="POST", - status_code=200, - ) - with httpx.Client() as client: - client = rg.Argilla(api_url="http://test_url") - client.api.questions.create_many(dataset_id=mock_dataset_id, questions=[mock_question]) - def test_create_span_question(self, httpx_mock: HTTPXMock): mock_dataset_id = uuid.uuid4() mock_return_value = { @@ -96,6 +30,7 @@ def test_create_span_question(self, httpx_mock: HTTPXMock): "name": "string", "title": "string", "required": True, + "dataset_id": str(mock_dataset_id), "settings": { "type": "span", "allow_overlapping": True, @@ -119,11 +54,12 @@ def test_create_span_question(self, httpx_mock: HTTPXMock): ) with httpx.Client() as _: - question = SpanQuestionModel( + question = QuestionModel( name="5044cv0wu5", title="string", description="string", required=True, + dataset_id=mock_dataset_id, settings={ "type": "span", "allow_overlapping": True, @@ -138,5 +74,5 @@ def test_create_span_question(self, httpx_mock: HTTPXMock): ) client = rg.Argilla(api_url="http://test_url") - created_question = client.api.questions.create(dataset_id=mock_dataset_id, question=question) + created_question = client.api.questions.create(question=question) assert created_question.model_dump(exclude_unset=True) == mock_return_value