Skip to content

Commit

Permalink
[ENHANCEMENT]: argilla-server: Enhance text search with simple quer…
Browse files Browse the repository at this point in the history
…y dsl (#5222)

# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR adds a tiny dsl to improve text queries based on [this
query](https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-simple-query-string-query.html#simple-query-string-syntax)

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Improvement (change adding some improvement to an existing
functionality)
- Documentation update

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Damián Pumar <[email protected]>
  • Loading branch information
3 people authored Aug 27, 2024
1 parent 60f3073 commit 6815e5c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,25 @@ declare namespace CSS {
};
}

const DSLChars = ["|", "+", "-", "*"];

export const useSearchTextHighlight = (fieldId: string) => {
const FIELD_ID_TO_HIGHLIGHT = `fields-content-${fieldId}`;
const HIGHLIGHT_CLASS = `search-text-highlight-${fieldId}`;

const scapeDSLChars = (value: string) => {
let output = value;

for (const char of DSLChars) {
output = output.replaceAll(char, " ");
}

return output
.split(" ")
.map((w) => w.trim())
.filter(Boolean);
};

const createRangesToHighlight = (
fieldComponent: HTMLElement,
searchText: string
Expand Down Expand Up @@ -89,7 +104,7 @@ export const useSearchTextHighlight = (fieldId: string) => {
};

const textNodes = getTextNodesUnder(fieldComponent);
const words = searchText.split(" ");
const words = scapeDSLChars(searchText);

for (const textNode of textNodes) {
for (const word of words) {
Expand Down
29 changes: 17 additions & 12 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,21 @@ def es_ids_query(ids: List[str]) -> dict:
return {"ids": {"values": ids}}


def es_simple_query_string(field_name: str, query: str) -> dict:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-simple-query-string-query.html
return {
"simple_query_string": {
"query": query,
"fields": [field_name],
"default_operator": "AND",
"analyze_wildcard": False,
"auto_generate_synonyms_phrase_query": False,
"fuzzy_max_expansions": 10,
"fuzzy_transpositions": False,
}
}


def es_nested_query(path: str, query: dict) -> dict:
return {
"nested": {
Expand Down Expand Up @@ -138,7 +153,7 @@ def es_field_for_metadata_property(metadata_property: Union[str, MetadataPropert


def es_field_for_record_field(field_name: str) -> str:
return f"fields.{field_name}"
return f"fields.{field_name or '*'}"


def es_field_for_response_property(property: str) -> str:
Expand Down Expand Up @@ -612,17 +627,7 @@ def _build_text_query(dataset: Dataset, text: Optional[Union[TextQuery, str]] =
if isinstance(text, str):
text = TextQuery(q=text)

if not text.field:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-multi-match-query.html
field_names = [
es_field_for_record_field(field.name)
for field in dataset.fields
if field.settings.get("type") == FieldType.text
]
return {"multi_match": {"query": text.q, "type": "cross_fields", "fields": field_names, "operator": "and"}}
else:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html
return {"match": {es_field_for_record_field(text.field): {"query": text.q, "operator": "and"}}}
return es_simple_query_string(es_field_for_record_field(text.field), query=text.q)

@staticmethod
def _mapping_for_fields(fields: List[Field]) -> dict:
Expand Down
6 changes: 6 additions & 0 deletions argilla-server/tests/unit/search_engine/test_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,9 @@ async def test_create_index_for_dataset_with_questions(
("00000", 1),
("card payment", 5),
("nothing", 0),
("cash | negative", 6), # OR
("cash + negative", 1), # AN
("-(cash | negative)", 3), # NOT
(TextQuery(q="card"), 5),
(TextQuery(q="account"), 1),
(TextQuery(q="payment"), 6),
Expand All @@ -558,6 +561,9 @@ async def test_create_index_for_dataset_with_questions(
(TextQuery(q="negative", field="label"), 4),
(TextQuery(q="00000", field="textId"), 1),
(TextQuery(q="card payment", field="text"), 5),
(TextQuery(q="cash | negative", field="text"), 3),
(TextQuery(q="cash + negative", field="text"), 0),
(TextQuery(q="-(cash | negative)", field="text"), 6),
],
)
async def test_search_with_query_string(
Expand Down

0 comments on commit 6815e5c

Please sign in to comment.