Skip to content

Commit

Permalink
[BUGFIX] argilla server: search on chat field when selected (#5504)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

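When a chat field is selected as the search target, the text search returns
no results. This change fixes it by querying the chat field's sub-fields
(`fields.<name>.*`) whenever the selected field is a chat field.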

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: José Francisco Calvo <[email protected]>
  • Loading branch information
frascuchon and jfcalvo authored Sep 17, 2024
1 parent 02c6e9d commit a6fb9ed
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
5 changes: 5 additions & 0 deletions argilla-server/src/argilla_server/models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,11 @@ def is_ready(self):
def distribution_strategy(self) -> DatasetDistributionStrategy:
    """Return the strategy named under the ``"strategy"`` key of ``self.distribution``.

    The stored value is passed to ``DatasetDistributionStrategy`` and the
    result is returned.
    """
    strategy = self.distribution["strategy"]
    return DatasetDistributionStrategy(strategy)

def field_by_name(self, name: str) -> Union["Field", None]:
    """Return the field in ``self.fields`` whose ``name`` equals ``name``.

    Args:
        name: Value compared against the ``name`` attribute of each entry
            in ``self.fields``.

    Returns:
        The first matching field in iteration order, or ``None`` when no
        entry has that name.
    """
    for field in self.fields:
        if field.name == name:
            return field

    # The declared return type includes None, so state the "not found"
    # result explicitly instead of relying on the implicit fall-through.
    return None

def metadata_property_by_name(self, name: str) -> Union["MetadataProperty", None]:
for metadata_property in self.metadata_properties:
if metadata_property.name == name:
Expand Down
16 changes: 14 additions & 2 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def es_field_for_metadata_property(metadata_property: Union[str, MetadataPropert


def es_field_for_record_field(field_name: str) -> str:
    """Build the ``fields.``-prefixed path for a record field.

    Args:
        field_name: Name of the record field, or a pattern such as ``"*"``
            or ``"<name>.*"``. An empty (falsy) value selects every field.

    Returns:
        ``"fields.<field_name>"``, or ``"fields.*"`` when ``field_name`` is
        empty.
    """
    # Single exit point. Falling back to the wildcard keeps an empty name
    # from producing a path such as "fields." or "fields.None".
    return f"fields.{field_name or '*'}"


def es_field_for_response_property(property: str) -> str:
Expand Down Expand Up @@ -647,7 +647,19 @@ def _build_text_query(dataset: Dataset, text: Optional[Union[TextQuery, str]] =
if isinstance(text, str):
text = TextQuery(q=text)

return es_simple_query_string(es_field_for_record_field(text.field), query=text.q)
if text.field:
field = dataset.field_by_name(text.field)
if field is None:
raise Exception(f"Field {text.field} not found in dataset {dataset.id}")

Check warning on line 653 in argilla-server/src/argilla_server/search_engine/commons.py

View check run for this annotation

Codecov / codecov/patch

argilla-server/src/argilla_server/search_engine/commons.py#L653

Added line #L653 was not covered by tests

if field.is_chat:
field_name = f"{text.field}.*"
else:
field_name = text.field
else:
field_name = "*"

return es_simple_query_string(es_field_for_record_field(field_name), query=text.q)

@staticmethod
def _mapping_for_fields(fields: List[Field]) -> dict:
Expand Down
23 changes: 23 additions & 0 deletions argilla-server/tests/unit/search_engine/test_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
VectorFactory,
VectorSettingsFactory,
ImageFieldFactory,
ChatFieldFactory,
)


Expand Down Expand Up @@ -596,6 +597,28 @@ async def test_search_with_query_string(

assert scores == sorted_scores

async def test_search_for_chat_field(self, search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch):
    # Dataset with a single chat field; both records carry the same conversation.
    field = await ChatFieldFactory.create(name="field")
    dataset = await DatasetFactory.create(fields=[field])

    messages = [{"role": "user", "content": "Hello world"}, {"role": "bot", "content": "Hi"}]
    records = await RecordFactory.create_batch(
        size=2,
        dataset=dataset,
        fields={field.name: messages},
    )

    await refresh_dataset(dataset)
    await refresh_records(records)

    await search_engine.create_index(dataset)
    await search_engine.index_records(dataset, records)

    # Search restricted to the chat field: "world" occurs in the user message
    # of every record, so both records are expected back.
    query = TextQuery(q="world", field=field.name)
    result = await search_engine.search(dataset, query=query)

    assert len(result.items) == 2
    assert result.total == 2

@pytest.mark.parametrize(
"statuses, expected_items",
[
Expand Down

0 comments on commit a6fb9ed

Please sign in to comment.