Skip to content

Commit

Permalink
Fix to Return Similarity Scores Along with Records (#5778)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

Closes #5777 

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->
Passes similar param in query method while fetching records and could
see returning similarity score.

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Sai Kiran Bonu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Francisco Aranda <[email protected]>
Co-authored-by: Paco Aranda <[email protected]>
  • Loading branch information
5 people authored Jan 20, 2025
1 parent 8efa39c commit dc3deab
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 24 deletions.
4 changes: 4 additions & 0 deletions argilla/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ These are the section headers that we use:

## [Unreleased]()

### Added

- Return similarity score when searching by similarity. ([#5778](https://github.com/argilla-io/argilla/pull/5778))

## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0)

### Fixed
Expand Down
28 changes: 18 additions & 10 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from uuid import UUID
from enum import Enum

Expand Down Expand Up @@ -86,7 +86,7 @@ def _limit_reached(self) -> bool:
return False
return self.__limit <= 0

def _next_record(self) -> Record:
def _next_record(self) -> Union[Record, Tuple[Record, float]]:
if self._limit_reached() or self._no_records():
raise StopIteration()

Expand All @@ -104,15 +104,23 @@ def _fetch_next_batch(self) -> None:
self.__records_batch = list(self._list())
self.__offset += len(self.__records_batch)

def _list(self) -> Sequence[Record]:
for record_model in self._fetch_from_server():
yield Record.from_model(model=record_model, dataset=self.__dataset)

def _fetch_from_server(self) -> List[RecordModel]:
def _list(self) -> Sequence[Union[Record, Tuple[Record, float]]]:
if not self.__client.api.datasets.exists(self.__dataset.id):
warnings.warn(f"Dataset {self.__dataset.id!r} does not exist on the server. Skipping...")
return []
return self._fetch_from_server_with_search() if self._is_search_query() else self._fetch_from_server_with_list()

if self._is_search_query():
records = self._fetch_from_server_with_search()

if self.__query.has_similar():
for record_model, score in records:
yield Record.from_model(model=record_model, dataset=self.__dataset), score
else:
for record_model, _ in records:
yield Record.from_model(model=record_model, dataset=self.__dataset)
else:
for record_model in self._fetch_from_server_with_list():
yield Record.from_model(model=record_model, dataset=self.__dataset)

def _fetch_from_server_with_list(self) -> List[RecordModel]:
return self.__client.api.records.list(
Expand All @@ -124,7 +132,7 @@ def _fetch_from_server_with_list(self) -> List[RecordModel]:
with_vectors=self.__with_vectors,
)

def _fetch_from_server_with_search(self) -> List[RecordModel]:
def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, float]]:
search_items, total = self.__client.api.records.search(
dataset_id=self.__dataset.id,
query=self.__query.api_model(),
Expand All @@ -134,7 +142,7 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]:
with_suggestions=self.__with_suggestions,
with_vectors=self.__with_vectors,
)
return [record_model for record_model, _ in search_items]
return search_items

def _is_search_query(self) -> bool:
return self.__query.has_search()
Expand Down
4 changes: 2 additions & 2 deletions argilla/src/argilla/records/_io/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, Tuple

from datasets import Dataset as HFDataset, Sequence
from datasets import Image, ClassLabel, Value
Expand Down Expand Up @@ -194,7 +194,7 @@ def _is_hf_dataset(dataset: Any) -> bool:
return isinstance(dataset, HFDataset)

@staticmethod
def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset:
def to_datasets(records: List[Union["Record", Tuple["Record", float]]], dataset: "Dataset") -> HFDataset:
"""
Export the records to a Hugging Face dataset.
Expand Down
18 changes: 13 additions & 5 deletions argilla/src/argilla/records/_io/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from collections import defaultdict
from typing import Any, Dict, List, TYPE_CHECKING, Union
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
from argilla import Record
Expand All @@ -24,7 +24,9 @@ class GenericIO:
It handles methods for exporting records to generic python formats."""

@staticmethod
def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Union[str, float, int, list]]]:
def to_list(
records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False
) -> List[Dict[str, Union[str, float, int, list]]]:
"""Export records to a list of dictionaries with either names or record index as keys.
Args:
flatten (bool): The structure of the exported dictionary.
Expand All @@ -48,7 +50,7 @@ def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Un

@classmethod
def to_dict(
cls, records: List["Record"], flatten: bool = False, orient: str = "names"
cls, records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False, orient: str = "names"
) -> Dict[str, Union[str, float, int, list]]:
"""Export records to a dictionary with either names or record index as keys.
Args:
Expand Down Expand Up @@ -79,10 +81,10 @@ def to_dict(
############################

@staticmethod
def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]:
def _record_to_dict(record: Union["Record", Tuple["Record", float]], flatten=False) -> Dict[str, Any]:
"""Converts a Record object to a dictionary for export.
Args:
record (Record): The Record object to convert.
record (Record): The Record object or the record and the linked score to convert.
flatten (bool): The structure of the exported dictionary.
- True: The record fields, metadata, suggestions and responses will be flattened
so that their keys becomes the keys of the record dictionary, using
Expand All @@ -92,6 +94,12 @@ def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]:
Returns:
A dictionary representing the record.
"""
if isinstance(record, tuple):
record, score = record

record_dict = GenericIO._record_to_dict(record, flatten)
record_dict["score"] = score
return record_dict

record_dict = record.to_dict()
if flatten:
Expand Down
5 changes: 3 additions & 2 deletions argilla/src/argilla/records/_io/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# limitations under the License.
import json
from pathlib import Path
from typing import List, Union
from typing import List, Tuple, Union

from argilla.records._resource import Record
from argilla.records._io import GenericIO


class JsonIO:
@staticmethod
def to_json(records: List["Record"], path: Union[Path, str]) -> Path:
def to_json(records: List[Union["Record", Tuple["Record", float]]], path: Union[Path, str]) -> Path:
"""
Export the records to a file on disk. This is a convenient shortcut for dataset.records(...).to_disk().
Expand All @@ -37,6 +37,7 @@ def to_json(records: List["Record"], path: Union[Path, str]) -> Path:
path = Path(path)
if path.exists():
raise FileExistsError(f"File {path} already exists.")

record_dicts = GenericIO.to_list(records, flatten=False)
with open(path, "w") as f:
json.dump(record_dicts, f)
Expand Down
5 changes: 4 additions & 1 deletion argilla/src/argilla/records/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ def __init__(
self.similar = similar

def has_search(self) -> bool:
return bool(self.query or self.similar or self.filter)
return bool(self.query or self.has_similar() or self.filter)

def has_similar(self) -> bool:
return bool(self.similar)

def api_model(self) -> SearchQueryModel:
model = SearchQueryModel()
Expand Down
8 changes: 4 additions & 4 deletions argilla/tests/integration/test_search_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_search_records_by_similar_value(self, client: Argilla, dataset: Dataset
)
)
assert len(records) == 1000
assert records[0].id == str(data[3]["id"])
assert records[0][0].id == str(data[3]["id"])

def test_search_records_by_least_similar_value(self, client: Argilla, dataset: Dataset):
data = [
Expand All @@ -194,7 +194,7 @@ def test_search_records_by_least_similar_value(self, client: Argilla, dataset: D
)
)
)
assert records[-1].id == str(data[3]["id"])
assert records[-1][0].id == str(data[3]["id"])

def test_search_records_by_similar_record(self, client: Argilla, dataset: Dataset):
data = [
Expand All @@ -218,7 +218,7 @@ def test_search_records_by_similar_record(self, client: Argilla, dataset: Datase
)
)
assert len(records) == 1000
assert records[0].id != str(record.id)
assert records[0][0].id != str(record.id)

def test_search_records_by_least_similar_record(self, client: Argilla, dataset: Dataset):
data = [
Expand All @@ -241,4 +241,4 @@ def test_search_records_by_least_similar_record(self, client: Argilla, dataset:
)
)
)
assert all(r.id != str(record.id) for r in records)
assert all(r.id != str(record.id) for r, s in records)
35 changes: 35 additions & 0 deletions argilla/tests/unit/test_io/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,38 @@ def test_to_list_flatten(self):
"q2.suggestion.agent": None,
}
]

def test_records_tuple_to_list(self):
record = rg.Record(fields={"field": "The field"}, metadata={"key": "value"})

records_list = GenericIO.to_list(
[
(record, 1.0),
(record, 0.5),
]
)

assert records_list == [
{
"id": str(record.id),
"status": record.status,
"_server_id": record._server_id,
"fields": {"field": "The field"},
"metadata": {"key": "value"},
"responses": {},
"vectors": {},
"suggestions": {},
"score": 1.0,
},
{
"id": str(record.id),
"status": record.status,
"_server_id": record._server_id,
"fields": {"field": "The field"},
"metadata": {"key": "value"},
"responses": {},
"vectors": {},
"suggestions": {},
"score": 0.5,
},
]

0 comments on commit dc3deab

Please sign in to comment.