diff --git a/argilla/src/argilla/_models/__init__.py b/argilla/src/argilla/_models/__init__.py index 0e3f21ded0..4302e4259a 100644 --- a/argilla/src/argilla/_models/__init__.py +++ b/argilla/src/argilla/_models/__init__.py @@ -22,7 +22,6 @@ from argilla._models._record._suggestion import SuggestionModel from argilla._models._record._response import UserResponseModel, ResponseStatus from argilla._models._record._vector import VectorModel, VectorValue -from argilla._models._record._metadata import MetadataModel, MetadataValue from argilla._models._search import ( SearchQueryModel, AndFilterModel, diff --git a/argilla/src/argilla/_models/_record/_metadata.py b/argilla/src/argilla/_models/_record/_metadata.py index c8124894b6..28318f6662 100644 --- a/argilla/src/argilla/_models/_record/_metadata.py +++ b/argilla/src/argilla/_models/_record/_metadata.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Annotated, Union, List +from typing import Annotated, Any from pydantic import BaseModel -MetadataValue = Annotated[Union[str, List[str], float, int, None], "The value of the metadata field dictionary"] - - class MetadataModel(BaseModel): """Schema for the metadata of a `Dataset`""" name: Annotated[str, "The name of the metadata field or key in the metadata dictionary"] - value: MetadataValue + value: Any diff --git a/argilla/src/argilla/_models/_record/_record.py b/argilla/src/argilla/_models/_record/_record.py index b4da821cba..e171b9ddd1 100644 --- a/argilla/src/argilla/_models/_record/_record.py +++ b/argilla/src/argilla/_models/_record/_record.py @@ -18,7 +18,7 @@ from pydantic import BaseModel, Field, field_serializer, field_validator -from argilla._models._record._metadata import MetadataModel, MetadataValue +from argilla._models._record._metadata import MetadataModel from argilla._models._record._response import UserResponseModel from argilla._models._record._suggestion import SuggestionModel from argilla._models._record._vector import VectorModel @@ -44,7 +44,7 @@ class RecordModel(ResourceModel): status: Literal["pending", "completed"] = "pending" fields: Optional[Dict[str, FieldValue]] = None - metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict) + metadata: Optional[Union[List[MetadataModel], Dict[str, Any]]] = Field(default_factory=dict) vectors: Optional[List[VectorModel]] = Field(default_factory=list) responses: Optional[List[UserResponseModel]] = Field(default_factory=list) suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple) diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index 8de89b11c1..c1791ecc94 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -20,14 +20,13 @@ from argilla._helpers._media import cast_image, uncast_image from argilla._models import ( FieldValue, - MetadataModel, - MetadataValue, RecordModel, SuggestionModel, UserResponseModel, VectorModel, VectorValue, ) +from argilla._models._record._metadata import MetadataModel from argilla._resource import Resource from argilla.responses import Response, UserResponse from argilla.suggestions import Suggestion @@ -61,7 +60,7 @@ def __init__( self, id: Optional[Union[UUID, str]] = None, fields: Optional[Dict[str, FieldValue]] = None, - metadata: Optional[Dict[str, MetadataValue]] = None, + metadata: Optional[Dict[str, Any]] = None, vectors: Optional[Dict[str, VectorValue]] = None, responses: Optional[List[Response]] = None, suggestions: Optional[List[Suggestion]] = None, @@ -331,7 +330,7 @@ def _is_chat(self, key: str) -> bool: class RecordMetadata(dict): """This is a container class for the metadata of a Record.""" - def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None: + def __init__(self, metadata: Optional[Dict[str, Any]] = None) -> None: super().__init__(metadata or {}) def to_dict(self) -> dict: diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index e67f0e1531..8f2ff8d58f 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -18,7 +18,8 @@ from argilla import Dataset, Record, Response, Settings, Suggestion, TextField, TextQuestion from argilla._exceptions import ArgillaError -from argilla._models import MetadataModel, RecordModel +from argilla._models import RecordModel +from argilla._models._record._metadata import MetadataModel @pytest.fixture()