diff --git a/argilla/src/argilla/records/_io/_generic.py b/argilla/src/argilla/records/_io/_generic.py index 518181f85d..878c7b81d8 100644 --- a/argilla/src/argilla/records/_io/_generic.py +++ b/argilla/src/argilla/records/_io/_generic.py @@ -84,26 +84,39 @@ def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]: Returns: A dictionary representing the record. """ + record_dict = record.to_dict() if flatten: - responses: dict = record_dict.pop("responses") - suggestions: dict = record_dict.pop("suggestions") - fields: dict = record_dict.pop("fields") - metadata: dict = record_dict.pop("metadata") - record_dict.update(fields) - record_dict.update(metadata) - question_names = set(suggestions.keys()).union(responses.keys()) - for question_name in question_names: - _suggestion: Union[Dict, None] = suggestions.get(question_name) - if _suggestion: - record_dict[f"{question_name}.suggestion"] = _suggestion.pop("value") - record_dict.update( - {f"{question_name}.suggestion.{key}": value for key, value in _suggestion.items()} - ) - for _response in responses.get(question_name, []): - user_id = _response.pop("user_id") - record_dict[f"{question_name}.response.{user_id}"] = _response.pop("value") - record_dict.update( - {f"{question_name}.response.{user_id}.{key}": value for key, value in _response.items()} - ) + record_dict.update( + **record_dict.pop("fields", {}), + **record_dict.pop("metadata", {}), + **record_dict.pop("vectors", {}), + ) + + record_dict.pop("responses") + record_dict.pop("suggestions") + + responses_dict = defaultdict(list) + for response in record.responses: + responses_key = f"{response.question_name}.responses" + responses_users_key = f"{responses_key}.users" + + responses_dict[responses_key].append(response.value) + responses_dict[responses_users_key].append(str(response.user_id)) + + suggestions_dict = {} + for suggestion in record.suggestions: + suggestion_key = f"{suggestion.question_name}.suggestion" + suggestion_agent_key = 
f"{suggestion_key}.agent" + suggestion_score_key = f"{suggestion_key}.score" + + suggestions_dict.update( + { + suggestion_key: suggestion.value, + suggestion_score_key: suggestion.score, + suggestion_agent_key: suggestion.agent, + } + ) + + record_dict.update({**responses_dict, **suggestions_dict}) return record_dict diff --git a/argilla/tests/unit/test_io/test_generic.py b/argilla/tests/unit/test_io/test_generic.py new file mode 100644 index 0000000000..446693f5b5 --- /dev/null +++ b/argilla/tests/unit/test_io/test_generic.py @@ -0,0 +1,58 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from uuid import uuid4

import argilla as rg
from argilla.records._io import GenericIO


class TestGenericIO:
    """Unit tests for the list export produced by ``GenericIO``."""

    def test_to_list_flatten(self):
        """Check the flattened layout of one record with several responses and suggestions.

        The expectation pins the following shape:
        - fields and metadata entries appear as top-level keys;
        - ``<question>.responses`` lists the response values and
          ``<question>.responses.users`` lists the matching user ids as
          strings, both in the order the responses were given;
        - each suggestion yields ``<question>.suggestion`` plus ``.score`` and
          ``.agent`` keys, with ``None`` for an agent that was not provided.
        """
        user_a, user_b, user_c = (uuid4() for _ in range(3))

        # Responses for q1 and q2 are interleaved on purpose, so the grouping
        # per question is exercised rather than relying on contiguous input.
        answered_by = (("q1", user_a), ("q2", user_a), ("q2", user_b), ("q1", user_c))
        responses = [
            rg.Response(question_name=question, value="value", user_id=user)
            for question, user in answered_by
        ]
        suggestions = [
            rg.Suggestion(question_name="q1", value="value", score=0.1, agent="test"),
            rg.Suggestion(question_name="q2", value="value", score=0.9),
        ]
        record = rg.Record(
            fields={"field": "The field"},
            metadata={"key": "value"},
            responses=responses,
            suggestions=suggestions,
        )

        flattened = GenericIO.to_list([record], flatten=True)

        expected = {
            "id": str(record.id),
            "_server_id": None,
            "field": "The field",
            "key": "value",
            "q1.responses": ["value", "value"],
            "q1.responses.users": [str(user_a), str(user_c)],
            "q2.responses": ["value", "value"],
            "q2.responses.users": [str(user_a), str(user_b)],
            "q1.suggestion": "value",
            "q1.suggestion.score": 0.1,
            "q1.suggestion.agent": "test",
            "q2.suggestion": "value",
            "q2.suggestion.score": 0.9,
            "q2.suggestion.agent": None,
        }
        assert flattened == [expected]