diff --git a/argilla/docs/how_to_guides/query_export.md b/argilla/docs/how_to_guides/query_export.md index 9be482c797..c0f4ae1b13 100644 --- a/argilla/docs/how_to_guides/query_export.md +++ b/argilla/docs/how_to_guides/query_export.md @@ -134,7 +134,7 @@ workspace = client.workspaces("my_workspace") dataset = client.datasets(name="my_dataset", workspace=workspace) status_filter = rg.Query( - filter=rg.Filter(("status", "==", "submitted")) + filter=rg.Filter(("response.status", "==", "submitted")) ) filtered_records = list(dataset.records(status_filter)) diff --git a/argilla/docs/how_to_guides/record.md b/argilla/docs/how_to_guides/record.md index 95590c9724..7e30550839 100644 --- a/argilla/docs/how_to_guides/record.md +++ b/argilla/docs/how_to_guides/record.md @@ -484,7 +484,7 @@ dataset.records.delete(records=records_to_delete) ```python status_filter = rg.Query( - filter = rg.Filter(("status", "==", "pending")) + filter = rg.Filter(("response.status", "==", "pending")) ) records_to_delete = list(dataset.records(status_filter)) diff --git a/argilla/src/argilla/records/_dataset_records.py b/argilla/src/argilla/records/_dataset_records.py index 2d1c351e68..a52a598c96 100644 --- a/argilla/src/argilla/records/_dataset_records.py +++ b/argilla/src/argilla/records/_dataset_records.py @@ -104,7 +104,7 @@ def _fetch_from_server_with_list(self) -> List[RecordModel]: def _fetch_from_server_with_search(self) -> List[RecordModel]: search_items, total = self.__client.api.records.search( dataset_id=self.__dataset.id, - query=self.__query.model, + query=self.__query.api_model(), limit=self.__batch_size, offset=self.__offset, with_responses=self.__with_responses, diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index adc56b5750..a4a465bd29 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -32,8 +32,7 @@ class Condition(Tuple[str, str, Any]): """This class is used to map user conditions to the internal filter models""" - @property - def model(self) -> FilterModel: + def api_model(self) -> FilterModel: field, operator, value = self field = field.strip() @@ -55,7 +54,7 @@ def model(self) -> FilterModel: def _extract_filter_scope(field: str) -> ScopeModel: field = field.strip() - if field == "status": + if field == "response.status": return ResponseFilterScopeModel(property="status") elif "metadata" in field: _, md_property = field.split(".") @@ -70,6 +69,8 @@ def _extract_filter_scope(field: str) -> ScopeModel: question, _ = field.split(".") return ResponseFilterScopeModel(question=question) else: # Question field -> Suggestion + # TODO: The default path would be raise an error instead of consider suggestions by default + # (can be confusing) return SuggestionFilterScopeModel(question=field) @@ -91,9 +92,8 @@ def __init__(self, conditions: Union[List[Tuple[str, str, Any]], Tuple[str, str, conditions = [conditions] self.conditions = [Condition(condition) for condition in conditions] - @property - def model(self) -> AndFilterModel: - return AndFilterModel.model_validate({"and": [condition.model for condition in self.conditions]}) + def api_model(self) -> AndFilterModel: + return AndFilterModel.model_validate({"and": [condition.api_model() for condition in self.conditions]}) class Query: @@ -112,8 +112,7 @@ def __init__(self, *, query: Union[str, None] = None, filter: Union[Filter, None self.query = query self.filter = filter - @property - def model(self) -> SearchQueryModel: + def api_model(self) -> SearchQueryModel: model = SearchQueryModel() if self.query is not None: @@ -121,7 +120,7 @@ def model(self) -> SearchQueryModel: model.query = QueryModel(text=text_query) if self.filter is not None: - model.filters = self.filter.model + model.filters = self.filter.api_model() return model diff --git a/argilla/tests/unit/test_search/test_filters.py b/argilla/tests/unit/test_search/test_filters.py new file mode 100644 index 0000000000..fbc9e077c9 --- /dev/null +++ b/argilla/tests/unit/test_search/test_filters.py @@ -0,0 +1,30 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argilla.records import Filter + + +class TestFilters: + def test_filter_by_responses_status(self): + test_filter = Filter(("response.status", "in", ["submitted", "discard"])) + assert test_filter.api_model().model_dump(by_alias=True) == { + "type": "and", + "and": [ + { + "scope": {"entity": "response", "property": "status", "question": None}, + "type": "terms", + "values": ["submitted", "discard"], + } + ], + }