Skip to content

Commit

Permalink
5272 bug pythondeployment copying records with suggestions produces a…
Browse files Browse the repository at this point in the history
…n unprocessableentityerror (#5282)

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR fixes errors when creating records with suggestions from other
datasets. It also removes the internal `question_id` and `id` parameters
from the `Suggestion.__init__` method.

Closes #5272 

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)


**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 22, 2024
1 parent 0e88ede commit 01f8e2e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 28 deletions.
5 changes: 2 additions & 3 deletions argilla/src/argilla/records/_mapping/_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,10 @@ def _map_suggestions(self, data: Dict[str, Any], mapping) -> List[Suggestion]:
parameters = {param.parameter_type: data.get(param.source) for param in route.parameters}
if parameters.get(ParameterType.VALUE) is None:
continue
schema_item = self._dataset.schema.get(name)
question = self._dataset.questions[name]
suggestion = Suggestion(
**parameters,
question_name=route.name,
question_id=schema_item.id,
question_name=question.name,
)
suggestions.append(suggestion)

Expand Down
29 changes: 8 additions & 21 deletions argilla/src/argilla/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Literal, Union, List, TYPE_CHECKING, Dict
from uuid import UUID

from argilla._models import SuggestionModel
from argilla._resource import Resource
Expand All @@ -29,13 +28,11 @@ class Suggestion(Resource):
Suggestions are rendered in the user interfaces as 'hints' or 'suggestions' for the user to review and accept or reject.
Attributes:
value (str): The value of the suggestion.add()
question_name (str): The name of the question that the suggestion is for.
type (str): The type of suggestion, either 'model' or 'human'.
value (str): The value of the suggestion
score (float): The score of the suggestion. For example, the probability of the model prediction.
agent (str): The agent that created the suggestion. For example, the model name.
question_id (UUID): The ID of the question that the suggestion is for.
type (str): The type of suggestion, either 'model' or 'human'.
"""

_model: SuggestionModel
Expand All @@ -47,8 +44,6 @@ def __init__(
score: Union[float, List[float], None] = None,
agent: Optional[str] = None,
type: Optional[Literal["model", "human"]] = None,
id: Optional[UUID] = None,
question_id: Optional[UUID] = None,
_record: Optional["Record"] = None,
) -> None:
super().__init__()
Expand All @@ -60,9 +55,7 @@ def __init__(

self.record = _record
self._model = SuggestionModel(
id=id,
question_name=question_name,
question_id=question_id,
value=value,
type=type,
score=score,
Expand All @@ -87,15 +80,6 @@ def question_name(self) -> Optional[str]:
def question_name(self, value: str) -> None:
self._model.question_name = value

@property
def question_id(self) -> Optional[UUID]:
"""The ID of the question that the suggestion is for."""
return self._model.question_id

@question_id.setter
def question_id(self, value: UUID) -> None:
self._model.question_id = value

@property
def type(self) -> Optional[Literal["model", "human"]]:
"""The type of suggestion, either 'model' or 'human'."""
Expand Down Expand Up @@ -125,7 +109,10 @@ def from_model(cls, model: SuggestionModel, dataset: "Dataset") -> "Suggestion":
model.question_name = question.name
model.value = cls.__from_model_value(model.value, question)

return cls(**model.model_dump())
instance = cls(question.name, model.value)
instance._model = model

return instance

def api_model(self) -> SuggestionModel:
if self.record is None or self.record.dataset is None:
Expand All @@ -134,8 +121,8 @@ def api_model(self) -> SuggestionModel:
question = self.record.dataset.settings.questions[self.question_name]
return SuggestionModel(
value=self.__to_model_value(self.value, question),
question_name=self.question_name,
question_id=self.question_id or question.id,
question_name=question.name,
question_id=question.id,
type=self._model.type,
score=self._model.score,
agent=self._model.agent,
Expand Down
39 changes: 35 additions & 4 deletions argilla/tests/integration/test_create_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@

import pytest

from argilla import Argilla, Dataset, Settings, TextField, RatingQuestion, LabelQuestion, Workspace
from argilla import (
Argilla,
Dataset,
Settings,
TextField,
RatingQuestion,
LabelQuestion,
Workspace,
VectorField,
TermsMetadataProperty,
)
from argilla.settings._task_distribution import TaskDistribution


Expand Down Expand Up @@ -127,20 +137,41 @@ def test_create_a_dataset_copy(self, client: Argilla, dataset_name: str):
settings=Settings(
fields=[TextField(name="text")],
questions=[RatingQuestion(name="question", values=[1, 2, 3, 4, 5])],
vectors=[VectorField(name="vector", dimensions=2)],
metadata=[TermsMetadataProperty(name="terms")],
),
).create()

dataset.records.log([{"text": "This is a text"}])
dataset.records.log(
[
{
"text": "This is a text",
"terms": ["a", "b"],
"vector": [1, 2],
"question": 3,
}
]
)

new_dataset = Dataset(
name=f"{dataset_name}_copy",
settings=dataset.settings,
).create()

records = list(dataset.records)
records = list(dataset.records(with_vectors=True))
new_dataset.records.log(records)

assert len(list(dataset.records)) == len(list(new_dataset.records))
expected_records = list(dataset.records(with_vectors=True))
records = list(new_dataset.records(with_vectors=True))
assert len(expected_records) == len(records)
assert len(records) == 1

record, expected_record = records[0], expected_records[0]

assert expected_record.metadata.to_dict() == record.metadata.to_dict()
assert expected_record.vectors.to_dict() == record.vectors.to_dict()
assert expected_record.suggestions.to_dict() == record.suggestions.to_dict()

assert dataset.distribution == new_dataset.distribution

def test_create_dataset_with_custom_task_distribution(self, client: Argilla, dataset_name: str):
Expand Down

0 comments on commit 01f8e2e

Please sign in to comment.