From 12d836e9b2fb07f677e0aa37ceea4ce15847bb03 Mon Sep 17 00:00:00 2001 From: swethakann Date: Wed, 3 Apr 2024 16:03:02 -0700 Subject: [PATCH] Add decay function support for MultiFunctionScoreQuery (#641) * Add decay function support for MultiFunctionScoreQuery --- .../main/proto/yelp/nrtsearch/search.proto | 30 +++ docs/queries/multi_function_score.rst | 33 ++- .../server/luceneserver/geo/GeoUtils.java | 10 + .../multifunction/DecayFilterFunction.java | 50 +++++ .../query/multifunction/DecayFunction.java | 52 +++++ .../ExponentialDecayFunction.java | 37 ++++ .../query/multifunction/FilterFunction.java | 8 + .../GeoPointDecayFilterFunction.java | 202 ++++++++++++++++++ .../multifunction/GuassianDecayFunction.java | 37 ++++ .../multifunction/LinearDecayFunction.java | 37 ++++ .../MultiFunctionScoreQueryTest.java | 184 +++++++++++++++- .../multifunction/registerFieldsMFSQ.json | 6 + 12 files changed, 684 insertions(+), 2 deletions(-) create mode 100644 src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java create mode 100644 src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java create mode 100644 src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java create mode 100644 src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java create mode 100644 src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassianDecayFunction.java create mode 100644 src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java diff --git a/clientlib/src/main/proto/yelp/nrtsearch/search.proto b/clientlib/src/main/proto/yelp/nrtsearch/search.proto index 28a85f004..54e9deb95 100644 --- a/clientlib/src/main/proto/yelp/nrtsearch/search.proto +++ b/clientlib/src/main/proto/yelp/nrtsearch/search.proto @@ -315,9 +315,39 @@ message MultiFunctionScoreQuery { oneof Function { // Produce score with score script definition Script script = 3; + // Produce score with a decay function + DecayFunction decayFunction = 4; } } + // Apply decay function to docs + message DecayFunction { + // Document field name to use + string fieldName = 1; + // Type of decay function to apply + DecayType decayType = 2; + // Origin point to calculate the distance + oneof Origin { + google.type.LatLng geoPoint = 3; + } + // Currently only distance based scale and offset units are supported + // Distance from origin + offset at which computed score will be equal to decay. Scale should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi" + string scale = 4; + // Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set. Offset should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi" + string offset = 5; + // Defines decay rate for scoring. Should be between (0, 1) + float decay = 6; + } + + enum DecayType { + // Exponential decay function + DECAY_TYPE_EXPONENTIAL = 0; + // Linear decay function + DECAY_TYPE_LINEAR = 1; + // Gaussian decay function + DECAY_TYPE_GUASSIAN = 2; + } + // How to combine multiple function scores to produce a final function score enum FunctionScoreMode { // Multiply weighted function scores together diff --git a/docs/queries/multi_function_score.rst b/docs/queries/multi_function_score.rst index 8b2839b37..782ed2cc3 100644 --- a/docs/queries/multi_function_score.rst +++ b/docs/queries/multi_function_score.rst @@ -22,9 +22,40 @@ Proto definition: oneof Function { // Produce score with score script definition Script script = 3; + // Produce score with a decay function + DecayFunction decayFunction = 4; } } + // Apply decay function to docs + message DecayFunction { + // Document field name to use + string fieldName = 1; + // Type of decay function to apply + DecayType decayType = 2; + // Origin point to calculate the distance + oneof Origin { + google.type.LatLng geoPoint = 3; + } + // Currently only distance based scale and offset units are supported + // Distance from origin + offset at which computed score will be equal to decay. Scale should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", 15 km", "5 m", "7 mi" + string scale = 4; + // Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set. Offset should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", 15 km", "5 m", "7 mi" + string offset = 5; + // Defines decay rate for scoring. Should be between (0, 1) + float decay = 6; + } + + enum DecayType { + // Exponential decay function + DECAY_TYPE_EXPONENTIAL = 0; + // Linear decay function + DECAY_TYPE_LINEAR = 1; + // Gaussian decay function + DECAY_TYPE_GUASSIAN = 2; + } + + // How to combine multiple function scores to produce a final function score enum FunctionScoreMode { // Multiply weighted function scores together @@ -55,4 +86,4 @@ Proto definition: float min_score = 5; // Determine minimal score is excluded or not. By default, it's false; bool min_excluded = 6; - } \ No newline at end of file + } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/geo/GeoUtils.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/geo/GeoUtils.java index c5c8deeb3..1cc39d409 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/geo/GeoUtils.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/geo/GeoUtils.java @@ -15,6 +15,8 @@ */ package com.yelp.nrtsearch.server.luceneserver.geo; +import org.apache.lucene.util.SloppyMath; + public class GeoUtils { private static final double KM_TO_M = 1000.0; @@ -64,4 +66,12 @@ public static double convertDistanceToADifferentUnit(double distanceNumber, Stri throw new IllegalArgumentException("Invalid unit " + unit); } } + + /** + * Return the distance (in meters) between 2 lat,lon geo points using the haversine method + * implemented by lucene + */ + public static double arcDistance(double lat1, double lon1, double lat2, double lon2) { + return SloppyMath.haversinMeters(lat1, lon1, lat2, lon2); + } } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java new file mode 100644 index 000000000..2791fc52a --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java @@ -0,0 +1,50 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; + +import com.yelp.nrtsearch.server.grpc.MultiFunctionScoreQuery; +import org.apache.lucene.search.Query; + +public abstract class DecayFilterFunction extends FilterFunction { + + /** + * Constructor. + * + * @param filterQuery filter to use when applying this function, or null if none + * @param weight weight multiple to scale the function score + * @param decayFunction to score a document with a function that decays depending on the distance + * between an origin point and a numeric doc field value + */ + public DecayFilterFunction( + Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) { + super(filterQuery, weight); + if (decayFunction.getDecay() <= 0 || decayFunction.getDecay() >= 1) { + throw new IllegalArgumentException( + "decay rate should be between (0, 1) but is " + decayFunction.getDecay()); + } + } + + protected DecayFunction getDecayType(MultiFunctionScoreQuery.DecayType decayType) { + return switch (decayType) { + case DECAY_TYPE_GUASSIAN -> new GuassianDecayFunction(); + case DECAY_TYPE_EXPONENTIAL -> new ExponentialDecayFunction(); + case DECAY_TYPE_LINEAR -> new LinearDecayFunction(); + default -> throw new IllegalArgumentException( + decayType + + " not supported. Only exponential, guassian and linear decay functions are supported"); + }; + } +} diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java new file mode 100644 index 000000000..af0205e71 --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; + +import org.apache.lucene.search.Explanation; + +public interface DecayFunction { + /** + * Computes the decayed score based on the provided distance, offset, and scale. + * + * @param distance the distance from a given origin point. + * @param offset the point after which the decay starts. + * @param scale scale factor that influences the rate of decay. This scale value is computed from + * the user given scale using the computeScale() method. + * @return the decayed score after applying the decay function + */ + double computeScore(double distance, double offset, double scale); + + /** + * Computes the adjusted scale based on a user given scale and decay rate. + * + * @param scale user given scale. + * @param decay decay rate that decides how the score decreases. + * @return adjusted scale which will be used by the computeScore() method. + */ + double computeScale(double scale, double decay); + + /** + * Provides an explanation for the computed score based on the given distance, offset, and scale. + * + * @param distance the distance from a given origin point. + * @param offset the point after which the decay starts. + * @param scale scale factor that influences the rate of decay. This scale value is computed from + * the user given scale using the computeScale() method. + * @return Explanation object that details the calculations involved in computing the decayed + * score. + */ + Explanation explainComputeScore(double distance, double offset, double scale); +} diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java new file mode 100644 index 000000000..e9556d648 --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; + +import org.apache.lucene.search.Explanation; + +public class ExponentialDecayFunction implements DecayFunction { + @Override + public double computeScore(double distance, double offset, double scale) { + return Math.exp(scale * Math.max(0.0, distance - offset)); + } + + @Override + public double computeScale(double scale, double decay) { + return Math.log(decay) / scale; + } + + @Override + public Explanation explainComputeScore(double distance, double offset, double scale) { + return Explanation.match( + (float) computeScore(distance, offset, scale), + "exp(" + scale + " * max(0.0, " + distance + " - " + offset); + } +} diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java index b73b9e9a4..5ea1fde4b 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java @@ -88,6 +88,14 @@ public static FilterFunction build( indexState.docLookup); return new ScriptFilterFunction( filterQuery, weight, filterFunctionGrpc.getScript(), scriptSource); + case DECAYFUNCTION: + MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction(); + if (decayFunction.hasGeoPoint()) { + return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState); + } else { + throw new IllegalArgumentException( + "Decay Function should contain a geoPoint for Origin field"); + } case FUNCTION_NOT_SET: return new WeightFilterFunction(filterQuery, weight); default: diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java new file mode 100644 index 000000000..98bbd8074 --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java @@ -0,0 +1,202 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; + +import com.google.type.LatLng; +import com.yelp.nrtsearch.server.grpc.MultiFunctionScoreQuery; +import com.yelp.nrtsearch.server.luceneserver.IndexState; +import com.yelp.nrtsearch.server.luceneserver.doc.LoadedDocValues; +import com.yelp.nrtsearch.server.luceneserver.doc.SegmentDocLookup; +import com.yelp.nrtsearch.server.luceneserver.field.FieldDef; +import com.yelp.nrtsearch.server.luceneserver.field.LatLonFieldDef; +import com.yelp.nrtsearch.server.luceneserver.geo.GeoPoint; +import com.yelp.nrtsearch.server.luceneserver.geo.GeoUtils; +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import org.apache.lucene.index.*; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; + +public class GeoPointDecayFilterFunction extends DecayFilterFunction { + + private final MultiFunctionScoreQuery.DecayFunction decayFunction; + private final String fieldName; + private final DecayFunction decayType; + private final double scale; + private final double offset; + private final double decay; + private final LatLng origin; + private final IndexState indexState; + + /** + * Constructor. + * + * @param filterQuery filter to use when applying this function, or null if none + * @param weight weight multiple to scale the function score + * @param decayFunction to score a document with a function that decays depending on the distance + * between an origin point and a geoPoint doc field value + * @param indexState indexState for validation and doc value lookup + */ + public GeoPointDecayFilterFunction( + Query filterQuery, + float weight, + MultiFunctionScoreQuery.DecayFunction decayFunction, + IndexState indexState) { + super(filterQuery, weight, decayFunction); + this.decayFunction = decayFunction; + this.fieldName = decayFunction.getFieldName(); + this.decayType = getDecayType(decayFunction.getDecayType()); + this.origin = decayFunction.getGeoPoint(); + this.decay = decayFunction.getDecay(); + double userGivenScale = GeoUtils.getDistance(decayFunction.getScale()); + this.scale = decayType.computeScale(userGivenScale, decay); + this.offset = + !decayFunction.getOffset().isEmpty() + ? GeoUtils.getDistance(decayFunction.getOffset()) + : 0.0; + this.indexState = indexState; + validateLatLonField(indexState.getField(fieldName)); + } + + public void validateLatLonField(FieldDef fieldDef) { + if (!(fieldDef instanceof LatLonFieldDef)) { + throw new IllegalArgumentException( + fieldName + + " should be a LAT_LON to apply geoPoint decay function but it is: " + + fieldDef.getType()); + } + LatLonFieldDef latLonFieldDef = (LatLonFieldDef) fieldDef; + // TODO: Add support for multi-value fields + if (latLonFieldDef.isMultiValue()) { + throw new IllegalArgumentException( + "Multivalued fields are not supported for decay functions yet"); + } + if (!latLonFieldDef.hasDocValues()) { + throw new IllegalStateException("No doc values present for LAT_LON field: " + fieldName); + } + } + + @Override + public LeafFunction getLeafFunction(LeafReaderContext leafContext) { + return new GeoPointDecayLeafFunction(leafContext); + } + + public final class GeoPointDecayLeafFunction implements LeafFunction { + + SegmentDocLookup segmentDocLookup; + + public GeoPointDecayLeafFunction(LeafReaderContext context) { + segmentDocLookup = indexState.docLookup.getSegmentLookup(context); + } + + @Override + public double score(int docId, float innerQueryScore) throws IOException { + segmentDocLookup.setDocId(docId); + LoadedDocValues geoPointLoadedDocValues = + (LoadedDocValues) segmentDocLookup.get(fieldName); + if (geoPointLoadedDocValues.isEmpty()) { + return 0.0; + } else { + GeoPoint latLng = geoPointLoadedDocValues.get(0); + double distance = + GeoUtils.arcDistance( + origin.getLatitude(), origin.getLongitude(), latLng.getLat(), latLng.getLon()); + double score = decayType.computeScore(distance, offset, scale); + return score * getWeight(); + } + } + + @Override + public Explanation explainScore(int docId, Explanation innerQueryScore) { + double score; + segmentDocLookup.setDocId(docId); + LoadedDocValues geoPointLoadedDocValues = + (LoadedDocValues) segmentDocLookup.get(fieldName); + if (!geoPointLoadedDocValues.isEmpty()) { + GeoPoint latLng = geoPointLoadedDocValues.get(0); + double distance = + GeoUtils.arcDistance( + origin.getLatitude(), origin.getLongitude(), latLng.getLat(), latLng.getLon()); + + Explanation distanceExp = + Explanation.match(distance, "arc distance calculated between two geoPoints"); + + score = decayType.computeScore(distance, offset, scale); + double finalScore = score * getWeight(); + return Explanation.match( + finalScore, + "final score with the provided decay function calculated by score * weight with " + + getWeight() + + " weight value and " + + score + + "score", + List.of(distanceExp, decayType.explainComputeScore(distance, offset, scale))); + } else { + score = 0.0; + return Explanation.match( + score, "score is 0.0 since no doc values were present for " + fieldName); + } + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(super.toString()).append(", decayFunction:"); + sb.append("fieldName: ").append(fieldName); + sb.append("decayType: ").append(decayType); + sb.append("origin: ").append(origin); + sb.append("scale: ").append(scale); + sb.append("offset: ").append(offset); + sb.append("decay: ").append(decay); + return sb.toString(); + } + + @Override + protected FilterFunction doRewrite( + IndexReader reader, boolean filterQueryRewritten, Query rewrittenFilterQuery) { + if (filterQueryRewritten) { + return new GeoPointDecayFilterFunction( + rewrittenFilterQuery, getWeight(), decayFunction, indexState); + } else { + return this; + } + } + + @Override + protected boolean doEquals(FilterFunction other) { + if (other == null) { + return false; + } + if (other.getClass() != this.getClass()) { + return false; + } + GeoPointDecayFilterFunction otherGeoPointDecayFilterFunction = + (GeoPointDecayFilterFunction) other; + return Objects.equals(fieldName, otherGeoPointDecayFilterFunction.fieldName) + && Objects.equals(decayType, otherGeoPointDecayFilterFunction.decayType) + && Objects.equals(origin, otherGeoPointDecayFilterFunction.origin) + && Double.compare(scale, otherGeoPointDecayFilterFunction.scale) == 0 + && Double.compare(offset, otherGeoPointDecayFilterFunction.offset) == 0 + && Double.compare(decay, otherGeoPointDecayFilterFunction.decay) == 0; + } + + @Override + protected int doHashCode() { + return Objects.hash(fieldName, decayType, origin, scale, offset, decay); + } +} diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassianDecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassianDecayFunction.java new file mode 100644 index 000000000..f291d8e6b --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassianDecayFunction.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; + +import org.apache.lucene.search.Explanation; + +public class GuassianDecayFunction implements DecayFunction { + @Override + public double computeScore(double distance, double offset, double scale) { + return Math.exp((-1.0 * Math.pow(Math.max(0.0, distance - offset), 2.0)) / (2.0 * scale)); + } + + @Override + public double computeScale(double scale, double decay) { + return (-1.0 * Math.pow(scale, 2.0)) / (2.0 * Math.log(decay)); + } + + @Override + public Explanation explainComputeScore(double distance, double offset, double scale) { + return Explanation.match( + (float) computeScore(distance, offset, scale), + "exp(- pow(max(0.0, |" + distance + " - " + offset + "), 2.0)/ (2.0 * " + scale + ")"); + } +} diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java new file mode 100644 index 000000000..e362cb424 --- /dev/null +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024 Yelp Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; + +import org.apache.lucene.search.Explanation; + +public class LinearDecayFunction implements DecayFunction { + @Override + public double computeScore(double distance, double offset, double scale) { + return Math.max(0.0, (scale - Math.max(0.0, distance - offset)) / scale); + } + + @Override + public double computeScale(double scale, double decay) { + return scale / (1.0 - decay); + } + + @Override + public Explanation explainComputeScore(double distance, double offset, double scale) { + return Explanation.match( + (float) computeScore(distance, offset, scale), + "max(0.0, (" + scale + " - max(0.0, " + distance + " - " + offset + ")) / " + scale + ")"); + } +} diff --git a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java index 6190d793f..ebf5db847 100644 --- a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java +++ b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java @@ -16,7 +16,9 @@ package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import com.google.type.LatLng; import com.yelp.nrtsearch.server.grpc.AddDocumentRequest; import com.yelp.nrtsearch.server.grpc.AddDocumentRequest.MultiValuedField; import com.yelp.nrtsearch.server.grpc.FieldDefRequest; @@ -76,6 +78,9 @@ protected void initIndex(String name) throws Exception { .putFields( "text_field", MultiValuedField.newBuilder().addValue("Document2 with term1 filter term").build()) + .putFields( + "lat_lon_field", + MultiValuedField.newBuilder().addValue("41.8781").addValue("-87.6298").build()) .build(); docs.add(request); request = @@ -86,6 +91,9 @@ protected void initIndex(String name) throws Exception { .putFields( "text_field", MultiValuedField.newBuilder().addValue("Document1 with term2 filter term").build()) + .putFields( + "lat_lon_field", + MultiValuedField.newBuilder().addValue("51.5074").addValue("-0.1278").build()) .build(); docs.add(request); request = @@ -98,6 +106,9 @@ protected void initIndex(String name) throws Exception { MultiValuedField.newBuilder() .addValue("Document2 with both term1 and term2 filter terms") .build()) + .putFields( + "lat_lon_field", + MultiValuedField.newBuilder().addValue("45.5051").addValue("-122.6750").build()) .build(); docs.add(request); addDocuments(docs.stream()); @@ -553,6 +564,172 @@ public void testScriptWithScore() { verifyResponseHits(response, List.of(2, 4), List.of(0.5030323266983032, 0.5073584914207458)); } + @Test + public void testExpDecayFunctionGeoPointWithWeight() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("Document2").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.99f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) + .setOffset("0 m") + .setScale("1 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .setWeight(0.7f) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta( + response, List.of(2, 4), List.of(2.3963971216289792E-6, 2.034676950471705E-18), 0.00000001); + } + + @Test + public void testExpDecayFunctionGeoPoint() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("Document2").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.99f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) + .setOffset("0 m") + .setScale("1 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta( + response, List.of(2, 4), List.of(3.4234246868436458E-6, 2.034676950471705E-18), 0.0); + } + + @Test + public void testExpDecayFunctionNoDocValue() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("none").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.99f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) + .setOffset("0 m") + .setScale("1 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta(response, List.of(1), List.of(0.0), 0.0); + } + + @Test + public void testLinearDecayFunctionGeoPoint() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("Document2").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.2f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_LINEAR) + .setOffset("100 km") + .setScale("6000 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta(response, List.of(2, 4), List.of(0.2910, 0.1358), 0.0001); + } + + @Test + public void testGuassDecayFunctionGeoPoint() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("Document2").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.5f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_GUASSIAN) + .setOffset("10000 km") + .setScale("100 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta(response, List.of(2, 4), List.of(0.3381, 0.2772), 0.0001); + } + + @Test + public void testInvalidGeoPointField() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + assertThrows( + Exception.class, + () -> { + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder() + .setField("text_field") + .setQuery("Document2") + .build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.99f) + .setDecayType( + MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) + .setOffset("0 m") + .setScale("1 km") + .setGeoPoint(latLng) + .setFieldName("text_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + }); + } + @Test public void testMultiMatchAll_multiply_multiply() { multiFunctionAndVerify( @@ -890,6 +1067,11 @@ private void multiFunctionAndVerify( private void verifyResponseHits( SearchResponse searchResponse, List ids, List scores) { + verifyResponseHitsWithDelta(searchResponse, ids, scores, 0.00001); + } + + private void verifyResponseHitsWithDelta( + SearchResponse searchResponse, List ids, List scores, Double delta) { assertEquals(ids.size(), scores.size()); Map scoresMap = new HashMap<>(); for (int i = 0; i < ids.size(); ++i) { @@ -903,7 +1085,7 @@ private void verifyResponseHits( hit.getFieldsOrThrow("doc_id").getFieldValue(0).getTextValue(), hit.getScore()); } for (Map.Entry entry : scoresMap.entrySet()) { - assertEquals(entry.getValue(), responseScoresMap.get(entry.getKey()), 0.00001); + assertEquals(entry.getValue(), responseScoresMap.get(entry.getKey()), delta); } } diff --git a/src/test/resources/search/query/multifunction/registerFieldsMFSQ.json b/src/test/resources/search/query/multifunction/registerFieldsMFSQ.json index feba3bc08..8b917bc21 100644 --- a/src/test/resources/search/query/multifunction/registerFieldsMFSQ.json +++ b/src/test/resources/search/query/multifunction/registerFieldsMFSQ.json @@ -20,6 +20,12 @@ "name": "double_field", "type": "DOUBLE", "storeDocValues": true + }, + { + "name": "lat_lon_field", + "type": "LAT_LON", + "storeDocValues": true, + "search": true } ] }