Skip to content

Commit

Permalink
[BUGFIX] argilla-server: Prevent error when sorting with opensearch (
Browse files Browse the repository at this point in the history
…#5297)

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR fixes errors found when sorting records in UI using OpenSearch
as the search engine.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 24, 2024
1 parent 9f56d20 commit bac699a
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 5 deletions.
1 change: 1 addition & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ These are the section headers that we use:

- Fixed SQLite connection settings not working correctly due to an outdated conditional. ([#5149](https://github.com/argilla-io/argilla/pull/5149))
- Fixed errors when `allowed_workspaces` in `.oauth.yaml` file is empty. ([#5273](https://github.com/argilla-io/argilla/pull/5273))
- Fixed errors when sorting with OpenSearch search engine. ([#5297](https://github.com/argilla-io/argilla/pull/5297))

### Removed

Expand Down
4 changes: 2 additions & 2 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ async def compute_metrics_for(self, metadata_property: MetadataProperty) -> Meta
def build_elasticsearch_filter(self, filter: Filter) -> Dict[str, Any]:
if isinstance(filter, AndFilter):
filters = [self.build_elasticsearch_filter(f) for f in filter.filters]
return es_bool_query(should=filters, minimum_should_match=len(filters))
return es_bool_query(must=filters)

if isinstance(filter.scope, ResponseFilterScope):
return self._response_filter_to_es_filter(filter)
Expand Down Expand Up @@ -859,7 +859,7 @@ async def _index_search_request(
query: dict,
size: Optional[int] = None,
from_: Optional[int] = None,
sort: Optional[str] = None,
sort: Optional[dict] = None,
aggregations: Optional[dict] = None,
) -> dict:
"""Executes request for search documents on a index"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def _index_search_request(
query: dict,
size: Optional[int] = None,
from_: Optional[int] = None,
sort: Optional[str] = None,
sort: Optional[dict] = None,
aggregations: Optional[dict] = None,
) -> dict:
return await self.client.search(
Expand Down
6 changes: 4 additions & 2 deletions argilla-server/src/argilla_server/search_engine/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,22 @@ async def _index_search_request(
query: dict,
size: Optional[int] = None,
from_: Optional[int] = None,
sort: str = None,
sort: Optional[dict] = None,
aggregations: Optional[dict] = None,
) -> dict:
body = {"query": query}
if aggregations:
body["aggs"] = aggregations

if sort:
body["sort"] = sort

return await self.client.search(
index=index,
body=body,
from_=from_,
size=size,
_source=False,
sort=sort or "_score:desc,id:asc",
track_total_hits=True,
)

Expand Down
10 changes: 10 additions & 0 deletions argilla-server/tests/unit/search_engine/test_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
RangeFilter,
Order,
RecordFilterScope,
AndFilter,
)
from argilla_server.search_engine.commons import (
BaseElasticAndOpenSearchEngine,
Expand Down Expand Up @@ -704,6 +705,15 @@ async def test_search_with_response_status_filter_with_no_user(
(RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.13, le=0.13), 1),
(RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.0), 7),
(RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), le=12.03), 5),
(
AndFilter(
filters=[
TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["negative"]),
RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=4),
]
),
1,
),
],
)
async def test_search_with_metadata_filter(
Expand Down

0 comments on commit bac699a

Please sign in to comment.