diff --git a/argilla/src/argilla/_exceptions/_api.py b/argilla/src/argilla/_exceptions/_api.py index 7ab82a90c5..b7809ff91b 100644 --- a/argilla/src/argilla/_exceptions/_api.py +++ b/argilla/src/argilla/_exceptions/_api.py @@ -11,14 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from httpx import HTTPStatusError -from argilla._exceptions._base import ArgillaErrorBase +from argilla._exceptions._base import ArgillaError -class ArgillaAPIError(ArgillaErrorBase): - pass +class ArgillaAPIError(ArgillaError): + def __init__(self, message: Optional[str] = None, status_code: int = 500): + """Base class for all Argilla API exceptions + Args: + message (str): The message to display when the exception is raised + status_code (int): The status code of the response that caused the exception + """ + super().__init__(message) + self.status_code = status_code class BadRequestError(ArgillaAPIError): diff --git a/argilla/src/argilla/_exceptions/_base.py b/argilla/src/argilla/_exceptions/_base.py index 8d08488d4a..57d0714071 100644 --- a/argilla/src/argilla/_exceptions/_base.py +++ b/argilla/src/argilla/_exceptions/_base.py @@ -11,20 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional -class ArgillaErrorBase(Exception): +class ArgillaError(Exception): message_stub = "Argilla SDK error" - message: str = message_stub - def __init__(self, message: str = message, status_code: int = 500): + def __init__(self, message: Optional[str] = None): """Base class for all Argilla exceptions Args: message (str): The message to display when the exception is raised - status_code (int): The status code of the response that caused the exception """ - super().__init__(message) - self.status_code = status_code + super().__init__(message or self.message_stub) def __str__(self): return f"{self.message_stub}: {self.__class__.__name__}: {super().__str__()}" diff --git a/argilla/src/argilla/_exceptions/_metadata.py b/argilla/src/argilla/_exceptions/_metadata.py index 504b8cc4d1..0b55afd43e 100644 --- a/argilla/src/argilla/_exceptions/_metadata.py +++ b/argilla/src/argilla/_exceptions/_metadata.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla._exceptions._base import ArgillaErrorBase +from argilla._exceptions._base import ArgillaError -class MetadataError(ArgillaErrorBase): +class MetadataError(ArgillaError): message: str = "Error defining dataset metadata settings" diff --git a/argilla/src/argilla/_exceptions/_records.py b/argilla/src/argilla/_exceptions/_records.py index ec7f045b82..21f1501ed5 100644 --- a/argilla/src/argilla/_exceptions/_records.py +++ b/argilla/src/argilla/_exceptions/_records.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla._exceptions._base import ArgillaErrorBase +from argilla._exceptions._base import ArgillaError -class RecordsIngestionError(ArgillaErrorBase): +class RecordsIngestionError(ArgillaError): pass diff --git a/argilla/src/argilla/_exceptions/_serialization.py b/argilla/src/argilla/_exceptions/_serialization.py index e81bbfdcd9..75fa95803a 100644 --- a/argilla/src/argilla/_exceptions/_serialization.py +++ b/argilla/src/argilla/_exceptions/_serialization.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla._exceptions._base import ArgillaErrorBase +from argilla._exceptions._base import ArgillaError -class ArgillaSerializeError(ArgillaErrorBase): +class ArgillaSerializeError(ArgillaError): pass diff --git a/argilla/src/argilla/_exceptions/_settings.py b/argilla/src/argilla/_exceptions/_settings.py index 0266b76468..8e8164fede 100644 --- a/argilla/src/argilla/_exceptions/_settings.py +++ b/argilla/src/argilla/_exceptions/_settings.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla._exceptions._base import ArgillaErrorBase +from argilla._exceptions._base import ArgillaError -class SettingsError(ArgillaErrorBase): +class SettingsError(ArgillaError): pass diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index 53c1321b4b..130d4712ac 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Iterable from uuid import UUID, uuid4 +from argilla._exceptions import ArgillaError from argilla._models import ( MetadataModel, RecordModel, @@ -324,6 +325,9 @@ def __iter__(self): def __getitem__(self, name: str): return self.__responses_by_question_name[name] + def __len__(self): + return len(self.__responses) + def __repr__(self) -> str: return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__() @@ -354,10 +358,21 @@ def add(self, response: Response) -> None: Args: response: The response to add. """ + self._check_response_already_exists(response) + response.record = self.record self.__responses.append(response) self.__responses_by_question_name[response.question_name].append(response) + def _check_response_already_exists(self, response: Response) -> None: + """Checks if a response for the same question name and user id already exists""" + for response in self.__responses_by_question_name[response.question_name]: + if response.user_id == response.user_id: + raise ArgillaError( + f"Response for question with name {response.question_name!r} and user id {response.user_id!r} " + f"already found. The responses for the same question name do not support more than one user" + ) + class RecordSuggestions(Iterable[Suggestion]): """This is a container class for the suggestions of a Record. diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index 09759430c7..a4b87a6374 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -14,7 +14,10 @@ import uuid +import pytest + from argilla import Record, Suggestion, Response +from argilla._exceptions import ArgillaError from argilla._models import MetadataModel @@ -62,3 +65,10 @@ def test_update_record_vectors(self): record.vectors["new-vector"] = [1.0, 2.0, 3.0] assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]} + + def test_add_record_response_for_the_same_question_and_user_id(self): + response = Response(question_name="question", value="value", user_id=uuid.uuid4()) + record = Record(fields={"name": "John"}, responses=[response]) + + with pytest.raises(ArgillaError): + record.responses.add(response)