Skip to content

Commit

Permalink
[ENHANCEMENT] argilla: simplify structure for flatten records to li…
Browse files Browse the repository at this point in the history
…st (#5137)

This PR changes the structure generated by `to_list(flatten=True)` to
simplify reading responses. The response content is split into values
and users, so no user ID is defined as part of the column name:

The result for the following record:

```python

record = rg.Record(
    fields={"field": "The field"},
    metadata={"key": "value"},
    responses=[
        rg.Response(question_name="q1", value="value", user_id=user_a),
        rg.Response(question_name="q2", value="value", user_id=user_a),
        rg.Response(question_name="q2", value="value", user_id=user_b),
        rg.Response(question_name="q1", value="value", user_id=user_c),
    ],
    suggestions=[
        rg.Suggestion(question_name="q1", value="value", score=0.1, agent="test"),
        rg.Suggestion(question_name="q2", value="value", score=0.9),
    ],
)
```
is :
```python
{
    "id": <record_id>,
    "_server_id": None,
    "field": "The field",
    "key": "value",
    "q1.responses": ["value", "value"],
    "q1.responses.users": [str(user_a), str(user_c)],
    "q2.responses": ["value", "value"],
    "q2.responses.users": [str(user_a), str(user_b)],
    "q1.suggestion": "value",
    "q1.suggestion.score": 0.1,
    "q1.suggestion.agent": "test",
    "q2.suggestion": "value",
    "q2.suggestion.score": 0.9,
    "q2.suggestion.agent": None,
}
```

Refs #4936

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- follows the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: burtenshaw <[email protected]>
  • Loading branch information
frascuchon and burtenshaw authored Jul 3, 2024
1 parent b78b61b commit 120160d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 20 deletions.
53 changes: 33 additions & 20 deletions argilla/src/argilla/records/_io/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,39 @@ def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]:
Returns:
A dictionary representing the record.
"""

record_dict = record.to_dict()
if flatten:
responses: dict = record_dict.pop("responses")
suggestions: dict = record_dict.pop("suggestions")
fields: dict = record_dict.pop("fields")
metadata: dict = record_dict.pop("metadata")
record_dict.update(fields)
record_dict.update(metadata)
question_names = set(suggestions.keys()).union(responses.keys())
for question_name in question_names:
_suggestion: Union[Dict, None] = suggestions.get(question_name)
if _suggestion:
record_dict[f"{question_name}.suggestion"] = _suggestion.pop("value")
record_dict.update(
{f"{question_name}.suggestion.{key}": value for key, value in _suggestion.items()}
)
for _response in responses.get(question_name, []):
user_id = _response.pop("user_id")
record_dict[f"{question_name}.response.{user_id}"] = _response.pop("value")
record_dict.update(
{f"{question_name}.response.{user_id}.{key}": value for key, value in _response.items()}
)
record_dict.update(
**record_dict.pop("fields", {}),
**record_dict.pop("metadata", {}),
**record_dict.pop("vectors", {}),
)

record_dict.pop("responses")
record_dict.pop("suggestions")

responses_dict = defaultdict(list)
for response in record.responses:
responses_key = f"{response.question_name}.responses"
responses_users_key = f"{responses_key}.users"

responses_dict[responses_key].append(response.value)
responses_dict[responses_users_key].append(str(response.user_id))

suggestions_dict = {}
for suggestion in record.suggestions:
suggestion_key = f"{suggestion.question_name}.suggestion"
suggestion_agent_key = f"{suggestion_key}.agent"
suggestion_score_key = f"{suggestion_key}.score"

suggestions_dict.update(
{
suggestion_key: suggestion.value,
suggestion_score_key: suggestion.score,
suggestion_agent_key: suggestion.agent,
}
)

record_dict.update({**responses_dict, **suggestions_dict})
return record_dict
58 changes: 58 additions & 0 deletions argilla/tests/unit/test_io/test_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import uuid4

import argilla as rg
from argilla.records._io import GenericIO


class TestGenericIO:
def test_to_list_flatten(self):
user_a, user_b, user_c = uuid4(), uuid4(), uuid4()

record = rg.Record(
fields={"field": "The field"},
metadata={"key": "value"},
responses=[
rg.Response(question_name="q1", value="value", user_id=user_a),
rg.Response(question_name="q2", value="value", user_id=user_a),
rg.Response(question_name="q2", value="value", user_id=user_b),
rg.Response(question_name="q1", value="value", user_id=user_c),
],
suggestions=[
rg.Suggestion(question_name="q1", value="value", score=0.1, agent="test"),
rg.Suggestion(question_name="q2", value="value", score=0.9),
],
)

records_list = GenericIO.to_list([record], flatten=True)
assert records_list == [
{
"id": str(record.id),
"_server_id": None,
"field": "The field",
"key": "value",
"q1.responses": ["value", "value"],
"q1.responses.users": [str(user_a), str(user_c)],
"q2.responses": ["value", "value"],
"q2.responses.users": [str(user_a), str(user_b)],
"q1.suggestion": "value",
"q1.suggestion.score": 0.1,
"q1.suggestion.agent": "test",
"q2.suggestion": "value",
"q2.suggestion.score": 0.9,
"q2.suggestion.agent": None,
}
]

0 comments on commit 120160d

Please sign in to comment.