From f55e4bed8acab5ed15dca2cd7f9fabbec499bff3 Mon Sep 17 00:00:00 2001 From: Sai Kiran Bonu Date: Thu, 9 Jan 2025 01:37:13 +0530 Subject: [PATCH 1/7] fix to return similarity score of records --- .../src/argilla/records/_dataset_records.py | 23 +++++++++------- argilla/src/argilla/records/_io/_datasets.py | 4 +-- argilla/src/argilla/records/_io/_generic.py | 26 ++++++++++++------- argilla/src/argilla/records/_io/_json.py | 13 +++++++--- 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index a166c49f34..77ad8bc371 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from uuid import UUID from enum import Enum @@ -104,27 +104,30 @@ def _fetch_next_batch(self) -> None: self.__records_batch = list(self._list()) self.__offset += len(self.__records_batch) - def _list(self) -> Sequence[Record]: - for record_model in self._fetch_from_server(): - yield Record.from_model(model=record_model, dataset=self.__dataset) + def _list(self) -> Sequence[Tuple[Record, Optional[float]]]: + for record_model, score in self._fetch_from_server(): + if score is not None: + yield Record.from_model(model=record_model, dataset=self.__dataset), score + else: + yield Record.from_model(model=record_model, dataset=self.__dataset) - def _fetch_from_server(self) -> List[RecordModel]: + def _fetch_from_server(self) -> List[Tuple[RecordModel, Optional[float]]]: if not self.__client.api.datasets.exists(self.__dataset.id): warnings.warn(f"Dataset {self.__dataset.id!r} does not exist on the server. Skipping...") return [] return self._fetch_from_server_with_search() if self._is_search_query() else self._fetch_from_server_with_list() - def _fetch_from_server_with_list(self) -> List[RecordModel]: - return self.__client.api.records.list( + def _fetch_from_server_with_list(self) -> List[Tuple[RecordModel, None]]: + return [(record_model, None) for record_model in self.__client.api.records.list( dataset_id=self.__dataset.id, limit=self.__batch_size, offset=self.__offset, with_responses=self.__with_responses, with_suggestions=self.__with_suggestions, with_vectors=self.__with_vectors, - ) + )] - def _fetch_from_server_with_search(self) -> List[RecordModel]: + def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, Optional[float]]]: search_items, total = self.__client.api.records.search( dataset_id=self.__dataset.id, query=self.__query.api_model(), @@ -134,7 +137,7 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]: with_suggestions=self.__with_suggestions, with_vectors=self.__with_vectors, ) - return [record_model for record_model, _ in search_items] + return search_items def _is_search_query(self) -> bool: return self.__query.has_search() diff --git a/argilla/src/argilla/records/_io/_datasets.py b/argilla/src/argilla/records/_io/_datasets.py index 975816cca2..32da8c60d8 100644 --- a/argilla/src/argilla/records/_io/_datasets.py +++ b/argilla/src/argilla/records/_io/_datasets.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, Tuple from datasets import Dataset as HFDataset, Sequence from datasets import Image, ClassLabel, Value @@ -194,7 +194,7 @@ def _is_hf_dataset(dataset: Any) -> bool: return isinstance(dataset, HFDataset) @staticmethod - def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset: + def to_datasets(records: List[Tuple["Record", Optional[float]]], dataset: "Dataset") -> HFDataset: """ Export the records to a Hugging Face dataset. diff --git a/argilla/src/argilla/records/_io/_generic.py b/argilla/src/argilla/records/_io/_generic.py index b1a1b8ef28..a9b021d290 100644 --- a/argilla/src/argilla/records/_io/_generic.py +++ b/argilla/src/argilla/records/_io/_generic.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from typing import Any, Dict, List, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: from argilla import Record @@ -24,7 +24,7 @@ class GenericIO: It handles methods for exporting records to generic python formats.""" @staticmethod - def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Union[str, float, int, list]]]: + def to_list(records: List[Tuple["Record", Optional[float]]], flatten: bool = False) -> List[Dict[str, Union[str, float, int, list]]]: """Export records to a list of dictionaries with either names or record index as keys. Args: flatten (bool): The structure of the exported dictionary. @@ -35,8 +35,10 @@ def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Un """ records_schema = set() dataset_records: list = [] - for record in records: + for record, score in records: record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) + if score is not None: # Include score only if it exists + record_dict["score"] = score records_schema.update([k for k in record_dict]) dataset_records.append(record_dict) @@ -48,8 +50,8 @@ def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Un @classmethod def to_dict( - cls, records: List["Record"], flatten: bool = False, orient: str = "names" - ) -> Dict[str, Union[str, float, int, list]]: + cls, records: List[Tuple["Record", Optional[float]]], flatten: bool = False, orient: str = "names" +) -> Dict[str, Union[str, float, int, list]]: """Export records to a dictionary with either names or record index as keys. Args: flatten (bool): The structure of the exported dictionary. @@ -63,13 +65,19 @@ def to_dict( """ if orient == "names": dataset_records: dict = defaultdict(list) - for record in cls.to_list(records, flatten=flatten): - for key, value in record.items(): + for record, score in records: + record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) + if score is not None: + record_dict["score"] = score + for key, value in record_dict.items(): dataset_records[key].append(value) elif orient == "index": dataset_records: dict = {} - for record in records: - dataset_records[record.id] = GenericIO._record_to_dict(record=record, flatten=flatten) + for record, score in records: + record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) + if score is not None: + record_dict["score"] = score + dataset_records[record.id] = record_dict else: raise ValueError(f"Invalid value for orient parameter: {orient}") return dict(dataset_records) diff --git a/argilla/src/argilla/records/_io/_json.py b/argilla/src/argilla/records/_io/_json.py index 0bf6208726..1144ce4395 100644 --- a/argilla/src/argilla/records/_io/_json.py +++ b/argilla/src/argilla/records/_io/_json.py @@ -13,7 +13,7 @@ # limitations under the License. import json from pathlib import Path -from typing import List, Union +from typing import List, Optional, Tuple, Union from argilla.records._resource import Record from argilla.records._io import GenericIO @@ -21,7 +21,7 @@ class JsonIO: @staticmethod - def to_json(records: List["Record"], path: Union[Path, str]) -> Path: + def to_json(records: List[Tuple["Record", Optional[float]]], path: Union[Path, str]) -> Path: """ Export the records to a file on disk. This is a convenient shortcut for dataset.records(...).to_disk(). @@ -37,7 +37,14 @@ def to_json(records: List["Record"], path: Union[Path, str]) -> Path: path = Path(path) if path.exists(): raise FileExistsError(f"File {path} already exists.") - record_dicts = GenericIO.to_list(records, flatten=False) + + record_dicts = [] + for record, score in records: + record_dict = GenericIO._record_to_dict(record=record, flatten=False) + if score is not None: + record_dict["score"] = score + record_dicts.append(record_dict) + with open(path, "w") as f: json.dump(record_dicts, f) return path From 6d4ec2e34eb08df5372ab98d173456b5177289ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 21:12:21 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../src/argilla/records/_dataset_records.py | 19 +++++++++++-------- argilla/src/argilla/records/_io/_generic.py | 6 ++++-- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index 77ad8bc371..44d0029120 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -118,14 +118,17 @@ def _fetch_from_server(self) -> List[Tuple[RecordModel, Optional[float]]]: return self._fetch_from_server_with_search() if self._is_search_query() else self._fetch_from_server_with_list() def _fetch_from_server_with_list(self) -> List[Tuple[RecordModel, None]]: - return [(record_model, None) for record_model in self.__client.api.records.list( - dataset_id=self.__dataset.id, - limit=self.__batch_size, - offset=self.__offset, - with_responses=self.__with_responses, - with_suggestions=self.__with_suggestions, - with_vectors=self.__with_vectors, - )] + return [ + (record_model, None) + for record_model in self.__client.api.records.list( + dataset_id=self.__dataset.id, + limit=self.__batch_size, + offset=self.__offset, + with_responses=self.__with_responses, + with_suggestions=self.__with_suggestions, + with_vectors=self.__with_vectors, + ) + ] def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, Optional[float]]]: search_items, total = self.__client.api.records.search( diff --git a/argilla/src/argilla/records/_io/_generic.py b/argilla/src/argilla/records/_io/_generic.py index a9b021d290..7b039c2bd2 100644 --- a/argilla/src/argilla/records/_io/_generic.py +++ b/argilla/src/argilla/records/_io/_generic.py @@ -24,7 +24,9 @@ class GenericIO: It handles methods for exporting records to generic python formats.""" @staticmethod - def to_list(records: List[Tuple["Record", Optional[float]]], flatten: bool = False) -> List[Dict[str, Union[str, float, int, list]]]: + def to_list( + records: List[Tuple["Record", Optional[float]]], flatten: bool = False + ) -> List[Dict[str, Union[str, float, int, list]]]: """Export records to a list of dictionaries with either names or record index as keys. Args: flatten (bool): The structure of the exported dictionary. @@ -51,7 +53,7 @@ def to_list(records: List[Tuple["Record", Optional[float]]], flatten: bool = Fal @classmethod def to_dict( cls, records: List[Tuple["Record", Optional[float]]], flatten: bool = False, orient: str = "names" -) -> Dict[str, Union[str, float, int, list]]: + ) -> Dict[str, Union[str, float, int, list]]: """Export records to a dictionary with either names or record index as keys. Args: flatten (bool): The structure of the exported dictionary. From 47383f76d4150a5651b7ebe1a7933dc96b77b74b Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 20 Jan 2025 13:07:28 +0100 Subject: [PATCH 3/7] return similarity score only when the similar filter is enabled to keep backward comp. as much as possible --- .../src/argilla/records/_dataset_records.py | 50 ++++++++++--------- argilla/src/argilla/records/_io/_datasets.py | 2 +- argilla/src/argilla/records/_io/_generic.py | 28 +++++------ argilla/src/argilla/records/_io/_json.py | 8 ++- argilla/src/argilla/records/_search.py | 5 +- 5 files changed, 48 insertions(+), 45 deletions(-) diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index 44d0029120..73a4acc7cc 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -86,7 +86,7 @@ def _limit_reached(self) -> bool: return False return self.__limit <= 0 - def _next_record(self) -> Record: + def _next_record(self) -> Union[Record, Tuple[Record, float]]: if self._limit_reached() or self._no_records(): raise StopIteration() @@ -104,33 +104,35 @@ def _fetch_next_batch(self) -> None: self.__records_batch = list(self._list()) self.__offset += len(self.__records_batch) - def _list(self) -> Sequence[Tuple[Record, Optional[float]]]: - for record_model, score in self._fetch_from_server(): - if score is not None: - yield Record.from_model(model=record_model, dataset=self.__dataset), score - else: - yield Record.from_model(model=record_model, dataset=self.__dataset) - - def _fetch_from_server(self) -> List[Tuple[RecordModel, Optional[float]]]: + def _list(self) -> Sequence[Union[Record, Tuple[Record, float]]]: if not self.__client.api.datasets.exists(self.__dataset.id): warnings.warn(f"Dataset {self.__dataset.id!r} does not exist on the server. Skipping...") return [] - return self._fetch_from_server_with_search() if self._is_search_query() else self._fetch_from_server_with_list() - - def _fetch_from_server_with_list(self) -> List[Tuple[RecordModel, None]]: - return [ - (record_model, None) - for record_model in self.__client.api.records.list( - dataset_id=self.__dataset.id, - limit=self.__batch_size, - offset=self.__offset, - with_responses=self.__with_responses, - with_suggestions=self.__with_suggestions, - with_vectors=self.__with_vectors, - ) - ] - def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, Optional[float]]]: + if self._is_search_query(): + records = self._fetch_from_server_with_search() + + if self.__query.has_similar(): + for record_model, score in records: + yield Record.from_model(model=record_model, dataset=self.__dataset), score + else: + for record_model, _ in records: + yield Record.from_model(model=record_model, dataset=self.__dataset) + else: + for record_model in self._fetch_from_server_with_list(): + yield Record.from_model(model=record_model, dataset=self.__dataset) + + def _fetch_from_server_with_list(self) -> List[RecordModel]: + return self.__client.api.records.list( + dataset_id=self.__dataset.id, + limit=self.__batch_size, + offset=self.__offset, + with_responses=self.__with_responses, + with_suggestions=self.__with_suggestions, + with_vectors=self.__with_vectors, + ) + + def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, float]]: search_items, total = self.__client.api.records.search( dataset_id=self.__dataset.id, query=self.__query.api_model(), diff --git a/argilla/src/argilla/records/_io/_datasets.py b/argilla/src/argilla/records/_io/_datasets.py index 32da8c60d8..50466c066a 100644 --- a/argilla/src/argilla/records/_io/_datasets.py +++ b/argilla/src/argilla/records/_io/_datasets.py @@ -194,7 +194,7 @@ def _is_hf_dataset(dataset: Any) -> bool: return isinstance(dataset, HFDataset) @staticmethod - def to_datasets(records: List[Tuple["Record", Optional[float]]], dataset: "Dataset") -> HFDataset: + def to_datasets(records: List[Union["Record", Tuple["Record", float]]], dataset: "Dataset") -> HFDataset: """ Export the records to a Hugging Face dataset. diff --git a/argilla/src/argilla/records/_io/_generic.py b/argilla/src/argilla/records/_io/_generic.py index 7b039c2bd2..594718ab82 100644 --- a/argilla/src/argilla/records/_io/_generic.py +++ b/argilla/src/argilla/records/_io/_generic.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: from argilla import Record @@ -25,7 +25,7 @@ class GenericIO: @staticmethod def to_list( - records: List[Tuple["Record", Optional[float]]], flatten: bool = False + records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False ) -> List[Dict[str, Union[str, float, int, list]]]: """Export records to a list of dictionaries with either names or record index as keys. Args: @@ -37,10 +37,8 @@ def to_list( """ records_schema = set() dataset_records: list = [] - for record, score in records: + for record in records: record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) - if score is not None: # Include score only if it exists - record_dict["score"] = score records_schema.update([k for k in record_dict]) dataset_records.append(record_dict) @@ -52,7 +50,7 @@ def to_list( @classmethod def to_dict( - cls, records: List[Tuple["Record", Optional[float]]], flatten: bool = False, orient: str = "names" + cls, records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False, orient: str = "names" ) -> Dict[str, Union[str, float, int, list]]: """Export records to a dictionary with either names or record index as keys. Args: @@ -67,18 +65,14 @@ def to_dict( """ if orient == "names": dataset_records: dict = defaultdict(list) - for record, score in records: + for record in records: record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) - if score is not None: - record_dict["score"] = score for key, value in record_dict.items(): dataset_records[key].append(value) elif orient == "index": dataset_records: dict = {} - for record, score in records: + for record in records: record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) - if score is not None: - record_dict["score"] = score dataset_records[record.id] = record_dict else: raise ValueError(f"Invalid value for orient parameter: {orient}") @@ -89,10 +83,10 @@ def to_dict( ############################ @staticmethod - def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]: + def _record_to_dict(record: Union["Record", Tuple["Record", float]], flatten=False) -> Dict[str, Any]: """Converts a Record object to a dictionary for export. Args: - record (Record): The Record object to convert. + record (Record): The Record object or the record and the linked score to convert. flatten (bool): The structure of the exported dictionary. - True: The record fields, metadata, suggestions and responses will be flattened so that their keys becomes the keys of the record dictionary, using @@ -102,6 +96,12 @@ def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]: Returns: A dictionary representing the record. """ + if isinstance(record, tuple): + record, score = record + + record_dict = GenericIO._record_to_dict(record, flatten) + record_dict["score"] = score + return record_dict record_dict = record.to_dict() if flatten: diff --git a/argilla/src/argilla/records/_io/_json.py b/argilla/src/argilla/records/_io/_json.py index 1144ce4395..a57b67e57d 100644 --- a/argilla/src/argilla/records/_io/_json.py +++ b/argilla/src/argilla/records/_io/_json.py @@ -13,7 +13,7 @@ # limitations under the License. import json from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union from argilla.records._resource import Record from argilla.records._io import GenericIO @@ -21,7 +21,7 @@ class JsonIO: @staticmethod - def to_json(records: List[Tuple["Record", Optional[float]]], path: Union[Path, str]) -> Path: + def to_json(records: List[Union["Record", Tuple["Record", float]]], path: Union[Path, str]) -> Path: """ Export the records to a file on disk. This is a convenient shortcut for dataset.records(...).to_disk(). @@ -39,10 +39,8 @@ def to_json(records: List[Tuple["Record", Optional[float]]], path: Union[Path, s raise FileExistsError(f"File {path} already exists.") record_dicts = [] - for record, score in records: + for record in records: record_dict = GenericIO._record_to_dict(record=record, flatten=False) - if score is not None: - record_dict["score"] = score record_dicts.append(record_dict) with open(path, "w") as f: diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index 15369eba63..1d82f459e5 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -175,7 +175,10 @@ def __init__( self.similar = similar def has_search(self) -> bool: - return bool(self.query or self.similar or self.filter) + return bool(self.query or self.has_similar() or self.filter) + + def has_similar(self) -> bool: + return bool(self.similar) def api_model(self) -> SearchQueryModel: model = SearchQueryModel() From df44e145e4240adb24ec0c923d78d3a5313f3e30 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 20 Jan 2025 13:08:40 +0100 Subject: [PATCH 4/7] fixing tests --- argilla/tests/integration/test_search_records.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/argilla/tests/integration/test_search_records.py b/argilla/tests/integration/test_search_records.py index f997366a25..60f5a6b94f 100644 --- a/argilla/tests/integration/test_search_records.py +++ b/argilla/tests/integration/test_search_records.py @@ -173,7 +173,7 @@ def test_search_records_by_similar_value(self, client: Argilla, dataset: Dataset ) ) assert len(records) == 1000 - assert records[0].id == str(data[3]["id"]) + assert records[0][0].id == str(data[3]["id"]) def test_search_records_by_least_similar_value(self, client: Argilla, dataset: Dataset): data = [ @@ -194,7 +194,7 @@ def test_search_records_by_least_similar_value(self, client: Argilla, dataset: D ) ) ) - assert records[-1].id == str(data[3]["id"]) + assert records[-1][0].id == str(data[3]["id"]) def test_search_records_by_similar_record(self, client: Argilla, dataset: Dataset): data = [ @@ -218,7 +218,7 @@ def test_search_records_by_similar_record(self, client: Argilla, dataset: Datase ) ) assert len(records) == 1000 - assert records[0].id != str(record.id) + assert records[0][0].id != str(record.id) def test_search_records_by_least_similar_record(self, client: Argilla, dataset: Dataset): data = [ @@ -241,4 +241,4 @@ def test_search_records_by_least_similar_record(self, client: Argilla, dataset: ) ) ) - assert all(r.id != str(record.id) for r in records) + assert all(r.id != str(record.id) for r, s in records) From 7ed3c8ac86c654466a5f3c709195e4a2256f6a92 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 20 Jan 2025 13:08:58 +0100 Subject: [PATCH 5/7] add io test --- argilla/tests/unit/test_io/test_generic.py | 35 ++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/argilla/tests/unit/test_io/test_generic.py b/argilla/tests/unit/test_io/test_generic.py index f0bed0fa82..e97fc08aa4 100644 --- a/argilla/tests/unit/test_io/test_generic.py +++ b/argilla/tests/unit/test_io/test_generic.py @@ -60,3 +60,38 @@ def test_to_list_flatten(self): "q2.suggestion.agent": None, } ] + + def test_records_tuple_to_list(self): + record = rg.Record(fields={"field": "The field"}, metadata={"key": "value"}) + + records_list = GenericIO.to_list( + [ + (record, 1.0), + (record, 0.5), + ] + ) + + assert records_list == [ + { + "id": str(record.id), + "status": record.status, + "_server_id": record._server_id, + "fields": {"field": "The field"}, + "metadata": {"key": "value"}, + "responses": {}, + "vectors": {}, + "suggestions": {}, + "score": 1.0, + }, + { + "id": str(record.id), + "status": record.status, + "_server_id": record._server_id, + "fields": {"field": "The field"}, + "metadata": {"key": "value"}, + "responses": {}, + "vectors": {}, + "suggestions": {}, + "score": 0.5, + }, + ] From 84c7dc9eec952c689debbb480c19f0d7f8d97a83 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 20 Jan 2025 13:11:59 +0100 Subject: [PATCH 6/7] chore: Update changelog --- argilla/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/argilla/CHANGELOG.md b/argilla/CHANGELOG.md index 94f72d3dcb..1b8d87b9e9 100644 --- a/argilla/CHANGELOG.md +++ b/argilla/CHANGELOG.md @@ -16,6 +16,10 @@ These are the section headers that we use: ## [Unreleased]() +### Added + +- Return similarity score when searching by similarity. ([#5778](https://github.com/argilla-io/argilla/pull/5778)) + ## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0) ### Fixed From 02e0d37e750bf1bd66db3e2c5192b3abde41b49d Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 20 Jan 2025 14:26:38 +0100 Subject: [PATCH 7/7] restore original code --- argilla/src/argilla/records/_io/_generic.py | 8 +++----- argilla/src/argilla/records/_io/_json.py | 6 +----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/argilla/src/argilla/records/_io/_generic.py b/argilla/src/argilla/records/_io/_generic.py index 594718ab82..5e0796db4b 100644 --- a/argilla/src/argilla/records/_io/_generic.py +++ b/argilla/src/argilla/records/_io/_generic.py @@ -65,15 +65,13 @@ def to_dict( """ if orient == "names": dataset_records: dict = defaultdict(list) - for record in records: - record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) - for key, value in record_dict.items(): + for record in cls.to_list(records, flatten=flatten): + for key, value in record.items(): dataset_records[key].append(value) elif orient == "index": dataset_records: dict = {} for record in records: - record_dict = GenericIO._record_to_dict(record=record, flatten=flatten) - dataset_records[record.id] = record_dict + dataset_records[record.id] = GenericIO._record_to_dict(record=record, flatten=flatten) else: raise ValueError(f"Invalid value for orient parameter: {orient}") return dict(dataset_records) diff --git a/argilla/src/argilla/records/_io/_json.py b/argilla/src/argilla/records/_io/_json.py index a57b67e57d..2f346f5bbb 100644 --- a/argilla/src/argilla/records/_io/_json.py +++ b/argilla/src/argilla/records/_io/_json.py @@ -38,11 +38,7 @@ def to_json(records: List[Union["Record", Tuple["Record", float]]], path: Union[ if path.exists(): raise FileExistsError(f"File {path} already exists.") - record_dicts = [] - for record in records: - record_dict = GenericIO._record_to_dict(record=record, flatten=False) - record_dicts.append(record_dict) - + record_dicts = GenericIO.to_list(records, flatten=False) with open(path, "w") as f: json.dump(record_dicts, f) return path