From dc3deab69eff479210ff7c4847c14f41fd451a9c Mon Sep 17 00:00:00 2001 From: Saikiranbonu1661 <141391289+Saikiranbonu1661@users.noreply.github.com> Date: Mon, 20 Jan 2025 19:04:12 +0530 Subject: [PATCH] Fix to Return Similarity Scores Along with Records (#5778) # Description Closes #5777 **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** Passes similar param in query method while fetching records and could see returning similarity score. **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Sai Kiran Bonu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Francisco Aranda Co-authored-by: Paco Aranda --- argilla/CHANGELOG.md | 4 +++ .../src/argilla/records/_dataset_records.py | 28 +++++++++------ argilla/src/argilla/records/_io/_datasets.py | 4 +-- argilla/src/argilla/records/_io/_generic.py | 18 +++++++--- argilla/src/argilla/records/_io/_json.py | 5 +-- argilla/src/argilla/records/_search.py | 5 ++- .../tests/integration/test_search_records.py | 8 ++--- argilla/tests/unit/test_io/test_generic.py | 35 +++++++++++++++++++ 8 files changed, 83 insertions(+), 24 deletions(-) diff --git a/argilla/CHANGELOG.md b/argilla/CHANGELOG.md index 94f72d3dcb..1b8d87b9e9 100644 --- a/argilla/CHANGELOG.md +++ b/argilla/CHANGELOG.md @@ -16,6 +16,10 @@ These are the section headers that we use: ## [Unreleased]() +### Added + +- Return similarity score when searching by similarity. ([#5778](https://github.com/argilla-io/argilla/pull/5778)) + ## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0) ### Fixed diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index a166c49f34..73a4acc7cc 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from uuid import UUID from enum import Enum @@ -86,7 +86,7 @@ def _limit_reached(self) -> bool: return False return self.__limit <= 0 - def _next_record(self) -> Record: + def _next_record(self) -> Union[Record, Tuple[Record, float]]: if self._limit_reached() or self._no_records(): raise StopIteration() @@ -104,15 +104,23 @@ def _fetch_next_batch(self) -> None: self.__records_batch = list(self._list()) self.__offset += len(self.__records_batch) - def _list(self) -> Sequence[Record]: - for record_model in self._fetch_from_server(): - yield Record.from_model(model=record_model, dataset=self.__dataset) - - def _fetch_from_server(self) -> List[RecordModel]: + def _list(self) -> Sequence[Union[Record, Tuple[Record, float]]]: if not self.__client.api.datasets.exists(self.__dataset.id): warnings.warn(f"Dataset {self.__dataset.id!r} does not exist on the server. Skipping...") return [] - return self._fetch_from_server_with_search() if self._is_search_query() else self._fetch_from_server_with_list() + + if self._is_search_query(): + records = self._fetch_from_server_with_search() + + if self.__query.has_similar(): + for record_model, score in records: + yield Record.from_model(model=record_model, dataset=self.__dataset), score + else: + for record_model, _ in records: + yield Record.from_model(model=record_model, dataset=self.__dataset) + else: + for record_model in self._fetch_from_server_with_list(): + yield Record.from_model(model=record_model, dataset=self.__dataset) def _fetch_from_server_with_list(self) -> List[RecordModel]: return self.__client.api.records.list( @@ -124,7 +132,7 @@ def _fetch_from_server_with_list(self) -> List[RecordModel]: with_vectors=self.__with_vectors, ) - def _fetch_from_server_with_search(self) -> List[RecordModel]: + def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, float]]: search_items, total = self.__client.api.records.search( dataset_id=self.__dataset.id, query=self.__query.api_model(), @@ -134,7 +142,7 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]: with_suggestions=self.__with_suggestions, with_vectors=self.__with_vectors, ) - return [record_model for record_model, _ in search_items] + return search_items def _is_search_query(self) -> bool: return self.__query.has_search() diff --git a/argilla/src/argilla/records/_io/_datasets.py b/argilla/src/argilla/records/_io/_datasets.py index 975816cca2..50466c066a 100644 --- a/argilla/src/argilla/records/_io/_datasets.py +++ b/argilla/src/argilla/records/_io/_datasets.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, Tuple from datasets import Dataset as HFDataset, Sequence from datasets import Image, ClassLabel, Value @@ -194,7 +194,7 @@ def _is_hf_dataset(dataset: Any) -> bool: return isinstance(dataset, HFDataset) @staticmethod - def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset: + def to_datasets(records: List[Union["Record", Tuple["Record", float]]], dataset: "Dataset") -> HFDataset: """ Export the records to a Hugging Face dataset. diff --git a/argilla/src/argilla/records/_io/_generic.py b/argilla/src/argilla/records/_io/_generic.py index b1a1b8ef28..5e0796db4b 100644 --- a/argilla/src/argilla/records/_io/_generic.py +++ b/argilla/src/argilla/records/_io/_generic.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict -from typing import Any, Dict, List, TYPE_CHECKING, Union +from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: from argilla import Record @@ -24,7 +24,9 @@ class GenericIO: It handles methods for exporting records to generic python formats.""" @staticmethod - def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Union[str, float, int, list]]]: + def to_list( + records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False + ) -> List[Dict[str, Union[str, float, int, list]]]: """Export records to a list of dictionaries with either names or record index as keys. Args: flatten (bool): The structure of the exported dictionary. @@ -48,7 +50,7 @@ def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Un @classmethod def to_dict( - cls, records: List["Record"], flatten: bool = False, orient: str = "names" + cls, records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False, orient: str = "names" ) -> Dict[str, Union[str, float, int, list]]: """Export records to a dictionary with either names or record index as keys. Args: @@ -79,10 +81,10 @@ def to_dict( ############################ @staticmethod - def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]: + def _record_to_dict(record: Union["Record", Tuple["Record", float]], flatten=False) -> Dict[str, Any]: """Converts a Record object to a dictionary for export. Args: - record (Record): The Record object to convert. + record (Record): The Record object or the record and the linked score to convert. flatten (bool): The structure of the exported dictionary. - True: The record fields, metadata, suggestions and responses will be flattened so that their keys becomes the keys of the record dictionary, using @@ -92,6 +94,12 @@ def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]: Returns: A dictionary representing the record. """ + if isinstance(record, tuple): + record, score = record + + record_dict = GenericIO._record_to_dict(record, flatten) + record_dict["score"] = score + return record_dict record_dict = record.to_dict() if flatten: diff --git a/argilla/src/argilla/records/_io/_json.py b/argilla/src/argilla/records/_io/_json.py index 0bf6208726..2f346f5bbb 100644 --- a/argilla/src/argilla/records/_io/_json.py +++ b/argilla/src/argilla/records/_io/_json.py @@ -13,7 +13,7 @@ # limitations under the License. import json from pathlib import Path -from typing import List, Union +from typing import List, Tuple, Union from argilla.records._resource import Record from argilla.records._io import GenericIO @@ -21,7 +21,7 @@ class JsonIO: @staticmethod - def to_json(records: List["Record"], path: Union[Path, str]) -> Path: + def to_json(records: List[Union["Record", Tuple["Record", float]]], path: Union[Path, str]) -> Path: """ Export the records to a file on disk. This is a convenient shortcut for dataset.records(...).to_disk(). @@ -37,6 +37,7 @@ def to_json(records: List["Record"], path: Union[Path, str]) -> Path: path = Path(path) if path.exists(): raise FileExistsError(f"File {path} already exists.") + record_dicts = GenericIO.to_list(records, flatten=False) with open(path, "w") as f: json.dump(record_dicts, f) diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index 15369eba63..1d82f459e5 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -175,7 +175,10 @@ def __init__( self.similar = similar def has_search(self) -> bool: - return bool(self.query or self.similar or self.filter) + return bool(self.query or self.has_similar() or self.filter) + + def has_similar(self) -> bool: + return bool(self.similar) def api_model(self) -> SearchQueryModel: model = SearchQueryModel() diff --git a/argilla/tests/integration/test_search_records.py b/argilla/tests/integration/test_search_records.py index f997366a25..60f5a6b94f 100644 --- a/argilla/tests/integration/test_search_records.py +++ b/argilla/tests/integration/test_search_records.py @@ -173,7 +173,7 @@ def test_search_records_by_similar_value(self, client: Argilla, dataset: Dataset ) ) assert len(records) == 1000 - assert records[0].id == str(data[3]["id"]) + assert records[0][0].id == str(data[3]["id"]) def test_search_records_by_least_similar_value(self, client: Argilla, dataset: Dataset): data = [ @@ -194,7 +194,7 @@ def test_search_records_by_least_similar_value(self, client: Argilla, dataset: D ) ) ) - assert records[-1].id == str(data[3]["id"]) + assert records[-1][0].id == str(data[3]["id"]) def test_search_records_by_similar_record(self, client: Argilla, dataset: Dataset): data = [ @@ -218,7 +218,7 @@ def test_search_records_by_similar_record(self, client: Argilla, dataset: Datase ) ) assert len(records) == 1000 - assert records[0].id != str(record.id) + assert records[0][0].id != str(record.id) def test_search_records_by_least_similar_record(self, client: Argilla, dataset: Dataset): data = [ @@ -241,4 +241,4 @@ def test_search_records_by_least_similar_record(self, client: Argilla, dataset: ) ) ) - assert all(r.id != str(record.id) for r in records) + assert all(r.id != str(record.id) for r, s in records) diff --git a/argilla/tests/unit/test_io/test_generic.py b/argilla/tests/unit/test_io/test_generic.py index f0bed0fa82..e97fc08aa4 100644 --- a/argilla/tests/unit/test_io/test_generic.py +++ b/argilla/tests/unit/test_io/test_generic.py @@ -60,3 +60,38 @@ def test_to_list_flatten(self): "q2.suggestion.agent": None, } ] + + def test_records_tuple_to_list(self): + record = rg.Record(fields={"field": "The field"}, metadata={"key": "value"}) + + records_list = GenericIO.to_list( + [ + (record, 1.0), + (record, 0.5), + ] + ) + + assert records_list == [ + { + "id": str(record.id), + "status": record.status, + "_server_id": record._server_id, + "fields": {"field": "The field"}, + "metadata": {"key": "value"}, + "responses": {}, + "vectors": {}, + "suggestions": {}, + "score": 1.0, + }, + { + "id": str(record.id), + "status": record.status, + "_server_id": record._server_id, + "fields": {"field": "The field"}, + "metadata": {"key": "value"}, + "responses": {}, + "vectors": {}, + "suggestions": {}, + "score": 0.5, + }, + ]