Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix to Return Similarity Scores Along with Records #5778

Merged
merged 8 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions argilla/src/argilla/records/_dataset_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from uuid import UUID
from enum import Enum

Expand Down Expand Up @@ -86,7 +86,7 @@ def _limit_reached(self) -> bool:
return False
return self.__limit <= 0

def _next_record(self) -> Record:
def _next_record(self) -> Union[Record, Tuple[Record, float]]:
if self._limit_reached() or self._no_records():
raise StopIteration()

Expand All @@ -104,15 +104,23 @@ def _fetch_next_batch(self) -> None:
self.__records_batch = list(self._list())
self.__offset += len(self.__records_batch)

def _list(self) -> Sequence[Record]:
for record_model in self._fetch_from_server():
yield Record.from_model(model=record_model, dataset=self.__dataset)

def _fetch_from_server(self) -> List[RecordModel]:
def _list(self) -> Sequence[Union[Record, Tuple[Record, float]]]:
    """Lazily yield the next page of records for this iterator.

    Yields:
        - plain ``Record`` objects for list queries and non-similarity searches;
        - ``(Record, score)`` tuples when the query includes a similarity
          search, where ``score`` is the similarity score returned by the server.

    If the dataset no longer exists on the server, a warning is emitted and the
    generator yields nothing.
    """
    if not self.__client.api.datasets.exists(self.__dataset.id):
        warnings.warn(f"Dataset {self.__dataset.id!r} does not exist on the server. Skipping...")
        # Bare `return` is the idiomatic way to end a generator early; a
        # `return []` here would be silently discarded (attached to the
        # StopIteration), never seen by callers iterating this generator.
        return

    if self._is_search_query():
        records = self._fetch_from_server_with_search()

        if self.__query.has_similar():
            # Similarity searches expose the score alongside each record.
            for record_model, score in records:
                yield Record.from_model(model=record_model, dataset=self.__dataset), score
        else:
            # Non-similarity searches still return (model, score) pairs from
            # the search API, but the score is not meaningful — drop it.
            for record_model, _ in records:
                yield Record.from_model(model=record_model, dataset=self.__dataset)
    else:
        for record_model in self._fetch_from_server_with_list():
            yield Record.from_model(model=record_model, dataset=self.__dataset)

def _fetch_from_server_with_list(self) -> List[RecordModel]:
return self.__client.api.records.list(
Expand All @@ -124,7 +132,7 @@ def _fetch_from_server_with_list(self) -> List[RecordModel]:
with_vectors=self.__with_vectors,
)

def _fetch_from_server_with_search(self) -> List[RecordModel]:
def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, float]]:
search_items, total = self.__client.api.records.search(
dataset_id=self.__dataset.id,
query=self.__query.api_model(),
Expand All @@ -134,7 +142,7 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]:
with_suggestions=self.__with_suggestions,
with_vectors=self.__with_vectors,
)
return [record_model for record_model, _ in search_items]
return search_items

def _is_search_query(self) -> bool:
    """Return True when the configured query requires the search endpoint
    (text query, similarity search, or filter) rather than a plain listing."""
    return self.__query.has_search()
Expand Down
4 changes: 2 additions & 2 deletions argilla/src/argilla/records/_io/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, Tuple

from datasets import Dataset as HFDataset, Sequence
from datasets import Image, ClassLabel, Value
Expand Down Expand Up @@ -194,7 +194,7 @@ def _is_hf_dataset(dataset: Any) -> bool:
return isinstance(dataset, HFDataset)

@staticmethod
def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset:
def to_datasets(records: List[Union["Record", Tuple["Record", float]]], dataset: "Dataset") -> HFDataset:
"""
Export the records to a Hugging Face dataset.

Expand Down
26 changes: 18 additions & 8 deletions argilla/src/argilla/records/_io/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from collections import defaultdict
from typing import Any, Dict, List, TYPE_CHECKING, Union
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
from argilla import Record
Expand All @@ -24,7 +24,9 @@ class GenericIO:
It handles methods for exporting records to generic python formats."""

@staticmethod
def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Union[str, float, int, list]]]:
def to_list(
records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False
) -> List[Dict[str, Union[str, float, int, list]]]:
"""Export records to a list of dictionaries with either names or record index as keys.
Args:
flatten (bool): The structure of the exported dictionary.
Expand All @@ -48,7 +50,7 @@ def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Un

@classmethod
def to_dict(
cls, records: List["Record"], flatten: bool = False, orient: str = "names"
cls, records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False, orient: str = "names"
) -> Dict[str, Union[str, float, int, list]]:
"""Export records to a dictionary with either names or record index as keys.
Args:
Expand All @@ -63,13 +65,15 @@ def to_dict(
"""
if orient == "names":
dataset_records: dict = defaultdict(list)
for record in cls.to_list(records, flatten=flatten):
for key, value in record.items():
for record in records:
record_dict = GenericIO._record_to_dict(record=record, flatten=flatten)
for key, value in record_dict.items():
dataset_records[key].append(value)
elif orient == "index":
dataset_records: dict = {}
for record in records:
dataset_records[record.id] = GenericIO._record_to_dict(record=record, flatten=flatten)
record_dict = GenericIO._record_to_dict(record=record, flatten=flatten)
dataset_records[record.id] = record_dict
else:
raise ValueError(f"Invalid value for orient parameter: {orient}")
return dict(dataset_records)
Expand All @@ -79,10 +83,10 @@ def to_dict(
############################

@staticmethod
def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]:
def _record_to_dict(record: Union["Record", Tuple["Record", float]], flatten=False) -> Dict[str, Any]:
"""Converts a Record object to a dictionary for export.
Args:
record (Record): The Record object to convert.
record (Record): The Record object or the record and the linked score to convert.
flatten (bool): The structure of the exported dictionary.
- True: The record fields, metadata, suggestions and responses will be flattened
so that their keys becomes the keys of the record dictionary, using
Expand All @@ -92,6 +96,12 @@ def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]:
Returns:
A dictionary representing the record.
"""
if isinstance(record, tuple):
record, score = record

record_dict = GenericIO._record_to_dict(record, flatten)
record_dict["score"] = score
return record_dict

record_dict = record.to_dict()
if flatten:
Expand Down
11 changes: 8 additions & 3 deletions argilla/src/argilla/records/_io/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# limitations under the License.
import json
from pathlib import Path
from typing import List, Union
from typing import List, Tuple, Union

from argilla.records._resource import Record
from argilla.records._io import GenericIO


class JsonIO:
@staticmethod
def to_json(records: List["Record"], path: Union[Path, str]) -> Path:
def to_json(records: List[Union["Record", Tuple["Record", float]]], path: Union[Path, str]) -> Path:
"""
Export the records to a file on disk. This is a convenient shortcut for dataset.records(...).to_disk().

Expand All @@ -37,7 +37,12 @@ def to_json(records: List["Record"], path: Union[Path, str]) -> Path:
path = Path(path)
if path.exists():
raise FileExistsError(f"File {path} already exists.")
record_dicts = GenericIO.to_list(records, flatten=False)

record_dicts = []
for record in records:
record_dict = GenericIO._record_to_dict(record=record, flatten=False)
record_dicts.append(record_dict)

with open(path, "w") as f:
json.dump(record_dicts, f)
return path
Expand Down
5 changes: 4 additions & 1 deletion argilla/src/argilla/records/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ def __init__(
self.similar = similar

def has_search(self) -> bool:
    """Return True when any search criterion is configured: a text query,
    a similarity (vector) search, or a filter.

    Delegating the similarity check to ``has_similar`` keeps the two
    predicates consistent with each other.
    """
    return bool(self.query or self.has_similar() or self.filter)

def has_similar(self) -> bool:
    """Return True when a similarity (vector) search has been configured."""
    # Double negation coerces the (possibly None) similarity spec to its
    # truth value, exactly like bool(self.similar).
    return not not self.similar

def api_model(self) -> SearchQueryModel:
model = SearchQueryModel()
Expand Down
8 changes: 4 additions & 4 deletions argilla/tests/integration/test_search_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_search_records_by_similar_value(self, client: Argilla, dataset: Dataset
)
)
assert len(records) == 1000
assert records[0].id == str(data[3]["id"])
assert records[0][0].id == str(data[3]["id"])

def test_search_records_by_least_similar_value(self, client: Argilla, dataset: Dataset):
data = [
Expand All @@ -194,7 +194,7 @@ def test_search_records_by_least_similar_value(self, client: Argilla, dataset: D
)
)
)
assert records[-1].id == str(data[3]["id"])
assert records[-1][0].id == str(data[3]["id"])

def test_search_records_by_similar_record(self, client: Argilla, dataset: Dataset):
data = [
Expand All @@ -218,7 +218,7 @@ def test_search_records_by_similar_record(self, client: Argilla, dataset: Datase
)
)
assert len(records) == 1000
assert records[0].id != str(record.id)
assert records[0][0].id != str(record.id)

def test_search_records_by_least_similar_record(self, client: Argilla, dataset: Dataset):
data = [
Expand All @@ -241,4 +241,4 @@ def test_search_records_by_least_similar_record(self, client: Argilla, dataset:
)
)
)
assert all(r.id != str(record.id) for r in records)
assert all(r.id != str(record.id) for r, s in records)
35 changes: 35 additions & 0 deletions argilla/tests/unit/test_io/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,38 @@ def test_to_list_flatten(self):
"q2.suggestion.agent": None,
}
]

def test_records_tuple_to_list(self):
    """GenericIO.to_list on (record, score) tuples should export one dict per
    tuple — the normal record export plus a "score" key carrying that tuple's
    similarity score."""
    record = rg.Record(fields={"field": "The field"}, metadata={"key": "value"})

    # The same record paired with two different scores: each pair must become
    # its own exported dict, differing only in "score".
    records_list = GenericIO.to_list(
        [
            (record, 1.0),
            (record, 0.5),
        ]
    )

    assert records_list == [
        {
            "id": str(record.id),
            "status": record.status,
            "_server_id": record._server_id,
            "fields": {"field": "The field"},
            "metadata": {"key": "value"},
            "responses": {},
            "vectors": {},
            "suggestions": {},
            "score": 1.0,  # score from the first tuple
        },
        {
            "id": str(record.id),
            "status": record.status,
            "_server_id": record._server_id,
            "fields": {"field": "The field"},
            "metadata": {"key": "value"},
            "responses": {},
            "vectors": {},
            "suggestions": {},
            "score": 0.5,  # score from the second tuple
        },
    ]
Loading