Skip to content

Commit

Permalink
Correct NeuralQueryBuilder doEquals() and doHashCode(). (#1045)
Browse files Browse the repository at this point in the history
Signed-off-by: Bo Zhang <[email protected]>
  • Loading branch information
bzhangam authored and martin-gaievski committed Jan 10, 2025
1 parent 5ecf7fc commit 477a9ec
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 136 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.knn.index.query.KNNQueryBuilder.MAX_DISTANCE_FIELD;
import static org.opensearch.knn.index.query.KNNQueryBuilder.MIN_SCORE_FIELD;
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_IMAGE_FIELD;

/**
* A util class which holds the logic to determine the min version supported by the request parameters
Expand All @@ -22,12 +23,14 @@ public final class MinClusterVersionUtil {

private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_QUERY_IMAGE_FIX = Version.V_2_19_0;

// Note this minimal version will act as a override
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
.put(MODEL_ID_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID)
.put(MAX_DISTANCE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
.put(MIN_SCORE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH)
.put(QUERY_IMAGE_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_QUERY_IMAGE_FIX)
.build();

public static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu

private final List<QueryBuilder> queries = new ArrayList<>();

private String fieldName;

static final int MAX_NUMBER_OF_SUB_QUERIES = 5;

public HybridQueryBuilder(StreamInput in) throws IOException {
Expand Down Expand Up @@ -255,7 +253,6 @@ protected boolean doEquals(HybridQueryBuilder obj) {
return false;
}
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(fieldName, obj.fieldName);
equalsBuilder.append(queries, obj.queries);
return equalsBuilder.isEquals();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
Expand Down Expand Up @@ -76,8 +76,7 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder>
@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");

@VisibleForTesting
static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");
public static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image");

public static final ParseField MODEL_ID_FIELD = new ParseField("model_id");

Expand Down Expand Up @@ -236,7 +235,16 @@ public static NeuralQueryBuilder.Builder builder() {
public NeuralQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
// The query image field was introduced since v2.11.0 through the
// https://github.com/opensearch-project/neural-search/pull/359 but at that time we didn't add it to
// NeuralQueryBuilder(StreamInput in) and doWriteTo(StreamOutput out) function. The fix will be
// introduced in v2.19.0 so we need this check for the backward compatibility.
if (isClusterOnOrAfterMinReqVersion(QUERY_IMAGE_FIELD.getPreferredName())) {
this.queryText = in.readOptionalString();
this.queryImage = in.readOptionalString();
} else {
this.queryText = in.readString();
}
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
this.modelId = in.readOptionalString();
Expand Down Expand Up @@ -265,7 +273,16 @@ public NeuralQueryBuilder(StreamInput in) throws IOException {
@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(this.fieldName);
out.writeString(this.queryText);
// The query image field was introduced since v2.11.0 through the
// https://github.com/opensearch-project/neural-search/pull/359 but at that time we didn't add it to
// NeuralQueryBuilder(StreamInput in) and doWriteTo(StreamOutput out) function. The fix will be
// introduced in v2.19.0 so we need this check for the backward compatibility.
if (isClusterOnOrAfterMinReqVersion(QUERY_IMAGE_FIELD.getPreferredName())) {
out.writeOptionalString(this.queryText);
out.writeOptionalString(this.queryImage);
} else {
out.writeString(this.queryText);
}
// If cluster version is on or after 2.11 then default model Id support is enabled
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
Expand All @@ -285,6 +302,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinReqVersion(EXPAND_NESTED_FIELD.getPreferredName())) {
out.writeOptionalBoolean(this.expandNested);
}

if (isClusterOnOrAfterMinReqVersion(METHOD_PARAMS_FIELD.getPreferredName())) {
MethodParametersParser.streamOutput(out, methodParameters, MinClusterVersionUtil::isClusterOnOrAfterMinReqVersion);
}
Expand All @@ -295,7 +313,12 @@ protected void doWriteTo(StreamOutput out) throws IOException {
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
if (Objects.nonNull(queryText)) {
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
}
if (Objects.nonNull(queryImage)) {
xContentBuilder.field(QUERY_IMAGE_FIELD.getPreferredName(), queryImage);
}
if (Objects.nonNull(modelId)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
Expand Down Expand Up @@ -501,15 +524,39 @@ protected boolean doEquals(NeuralQueryBuilder obj) {
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(fieldName, obj.fieldName);
equalsBuilder.append(queryText, obj.queryText);
equalsBuilder.append(queryImage, obj.queryImage);
equalsBuilder.append(modelId, obj.modelId);
equalsBuilder.append(k, obj.k);
equalsBuilder.append(maxDistance, obj.maxDistance);
equalsBuilder.append(minScore, obj.minScore);
equalsBuilder.append(expandNested, obj.expandNested);
equalsBuilder.append(getVector(vectorSupplier), getVector(obj.vectorSupplier));
equalsBuilder.append(filter, obj.filter);
equalsBuilder.append(methodParameters, obj.methodParameters);
equalsBuilder.append(rescoreContext, obj.rescoreContext);
return equalsBuilder.isEquals();
}

@Override
protected int doHashCode() {
return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(k).toHashCode();
return Objects.hash(
fieldName,
queryText,
queryImage,
modelId,
k,
maxDistance,
minScore,
expandNested,
Arrays.hashCode(getVector(vectorSupplier)),
filter,
methodParameters,
rescoreContext
);
}

private float[] getVector(final Supplier<float[]> vectorSupplier) {
return Objects.isNull(vectorSupplier) ? null : vectorSupplier.get();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,6 @@ public void testStreams_whenWrittingToStream_thenSuccessful() {
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.k(K)
.vectorSupplier(TEST_VECTOR_SUPPLIER)
.build();

original.add(neuralQueryBuilder);
Expand Down
Loading

0 comments on commit 477a9ec

Please sign in to comment.