Skip to content

Commit

Permalink
[BUGFIX] argilla: raise error adding record responses when a respon…
Browse files Browse the repository at this point in the history
…se with same question_name and user_id is found (#5209)

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR prevents adding multiple question responses per user which
result in a server error.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 12, 2024
1 parent 5cdd7f5 commit 7bd8ce3
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 17 deletions.
14 changes: 11 additions & 3 deletions argilla/src/argilla/_exceptions/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from httpx import HTTPStatusError

from argilla._exceptions._base import ArgillaErrorBase
from argilla._exceptions._base import ArgillaError


class ArgillaAPIError(ArgillaErrorBase):
pass
class ArgillaAPIError(ArgillaError):
def __init__(self, message: Optional[str] = None, status_code: int = 500):
"""Base class for all Argilla API exceptions
Args:
message (str): The message to display when the exception is raised
status_code (int): The status code of the response that caused the exception
"""
super().__init__(message)
self.status_code = status_code


class BadRequestError(ArgillaAPIError):
Expand Down
10 changes: 4 additions & 6 deletions argilla/src/argilla/_exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional


class ArgillaErrorBase(Exception):
class ArgillaError(Exception):
message_stub = "Argilla SDK error"
message: str = message_stub

def __init__(self, message: str = message, status_code: int = 500):
def __init__(self, message: Optional[str] = None):
"""Base class for all Argilla exceptions
Args:
message (str): The message to display when the exception is raised
status_code (int): The status code of the response that caused the exception
"""
super().__init__(message)
self.status_code = status_code
super().__init__(message or self.message_stub)

def __str__(self):
return f"{self.message_stub}: {self.__class__.__name__}: {super().__str__()}"
Expand Down
4 changes: 2 additions & 2 deletions argilla/src/argilla/_exceptions/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla._exceptions._base import ArgillaErrorBase
from argilla._exceptions._base import ArgillaError


class MetadataError(ArgillaErrorBase):
class MetadataError(ArgillaError):
message: str = "Error defining dataset metadata settings"
4 changes: 2 additions & 2 deletions argilla/src/argilla/_exceptions/_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla._exceptions._base import ArgillaErrorBase
from argilla._exceptions._base import ArgillaError


class RecordsIngestionError(ArgillaErrorBase):
class RecordsIngestionError(ArgillaError):
pass
4 changes: 2 additions & 2 deletions argilla/src/argilla/_exceptions/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla._exceptions._base import ArgillaErrorBase
from argilla._exceptions._base import ArgillaError


class ArgillaSerializeError(ArgillaErrorBase):
class ArgillaSerializeError(ArgillaError):
pass
4 changes: 2 additions & 2 deletions argilla/src/argilla/_exceptions/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla._exceptions._base import ArgillaErrorBase
from argilla._exceptions._base import ArgillaError


class SettingsError(ArgillaErrorBase):
class SettingsError(ArgillaError):
pass
15 changes: 15 additions & 0 deletions argilla/src/argilla/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Iterable
from uuid import UUID, uuid4

from argilla._exceptions import ArgillaError
from argilla._models import (
MetadataModel,
RecordModel,
Expand Down Expand Up @@ -324,6 +325,9 @@ def __iter__(self):
def __getitem__(self, name: str):
return self.__responses_by_question_name[name]

def __len__(self):
return len(self.__responses)

def __repr__(self) -> str:
return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__()

Expand Down Expand Up @@ -354,10 +358,21 @@ def add(self, response: Response) -> None:
Args:
response: The response to add.
"""
self._check_response_already_exists(response)

response.record = self.record
self.__responses.append(response)
self.__responses_by_question_name[response.question_name].append(response)

def _check_response_already_exists(self, response: Response) -> None:
"""Checks if a response for the same question name and user id already exists"""
for response in self.__responses_by_question_name[response.question_name]:
if response.user_id == response.user_id:
raise ArgillaError(
f"Response for question with name {response.question_name!r} and user id {response.user_id!r} "
f"already found. The responses for the same question name do not support more than one user"
)


class RecordSuggestions(Iterable[Suggestion]):
"""This is a container class for the suggestions of a Record.
Expand Down
10 changes: 10 additions & 0 deletions argilla/tests/unit/test_resources/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import uuid

import pytest

from argilla import Record, Suggestion, Response
from argilla._exceptions import ArgillaError
from argilla._models import MetadataModel


Expand Down Expand Up @@ -62,3 +65,10 @@ def test_update_record_vectors(self):

record.vectors["new-vector"] = [1.0, 2.0, 3.0]
assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]}

def test_add_record_response_for_the_same_question_and_user_id(self):
response = Response(question_name="question", value="value", user_id=uuid.uuid4())
record = Record(fields={"name": "John"}, responses=[response])

with pytest.raises(ArgillaError):
record.responses.add(response)

0 comments on commit 7bd8ce3

Please sign in to comment.