Skip to content

Commit

Permalink
Add decay function support for MultiFunctionScoreQuery (#641)
Browse files Browse the repository at this point in the history
* Add decay function support for MultiFunctionScoreQuery
  • Loading branch information
swethakann authored Apr 3, 2024
1 parent 915da1a commit 12d836e
Show file tree
Hide file tree
Showing 12 changed files with 684 additions and 2 deletions.
30 changes: 30 additions & 0 deletions clientlib/src/main/proto/yelp/nrtsearch/search.proto
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,39 @@ message MultiFunctionScoreQuery {
oneof Function {
// Produce score with score script definition
Script script = 3;
// Produce score with a decay function
DecayFunction decayFunction = 4;
}
}

// Score docs with a decay function: the score decreases as the distance between the
// doc's field value and the origin grows.
message DecayFunction {
// Name of the document field whose value is compared against the origin
string fieldName = 1;
// Type of decay function to apply
DecayType decayType = 2;
// Origin point from which the distance is calculated.
// A geoPoint origin must be set; the server rejects a decay function without one.
oneof Origin {
google.type.LatLng geoPoint = 3;
}
// Currently only distance based scale and offset units are supported.
// Distance from (origin + offset) at which the computed score will be equal to decay.
// Scale is a distance followed by an optional unit (m, km, mi); the space before the unit is optional.
// The default unit is meters. Ex: "10", "15 km", "5 m", "7 mi"
string scale = 4;
// The decay function is only computed for docs with a distance greater than offset. Defaults to 0.0 if not set.
// Offset is a distance followed by an optional unit (m, km, mi); the space before the unit is optional.
// The default unit is meters. Ex: "10", "15 km", "5 m", "7 mi"
string offset = 5;
// Decay rate used for scoring. Must be strictly between 0 and 1 (both exclusive);
// values outside this range are rejected.
float decay = 6;
}

// Shape of the decay curve applied by a DecayFunction
enum DecayType {
// Exponential decay function
DECAY_TYPE_EXPONENTIAL = 0;
// Linear decay function
DECAY_TYPE_LINEAR = 1;
// Gaussian decay function.
// NOTE(review): "GUASSIAN" is a misspelling of "GAUSSIAN"; the identifier is part of the
// generated API, so renaming it would break existing clients.
DECAY_TYPE_GUASSIAN = 2;
}

// How to combine multiple function scores to produce a final function score
enum FunctionScoreMode {
// Multiply weighted function scores together
Expand Down
33 changes: 32 additions & 1 deletion docs/queries/multi_function_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,40 @@ Proto definition:
oneof Function {
// Produce score with score script definition
Script script = 3;
// Produce score with a decay function
DecayFunction decayFunction = 4;
}
}
// Apply decay function to docs
message DecayFunction {
// Document field name to use
string fieldName = 1;
// Type of decay function to apply
DecayType decayType = 2;
// Origin point to calculate the distance
oneof Origin {
google.type.LatLng geoPoint = 3;
}
// Currently only distance based scale and offset units are supported
// Distance from origin + offset at which computed score will be equal to decay. Scale should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi"
string scale = 4;
// Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set. Offset should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi"
string offset = 5;
// Defines decay rate for scoring. Should be between (0, 1)
float decay = 6;
}
enum DecayType {
// Exponential decay function
DECAY_TYPE_EXPONENTIAL = 0;
// Linear decay function
DECAY_TYPE_LINEAR = 1;
// Gaussian decay function
DECAY_TYPE_GUASSIAN = 2;
}
// How to combine multiple function scores to produce a final function score
enum FunctionScoreMode {
// Multiply weighted function scores together
Expand Down Expand Up @@ -55,4 +86,4 @@ Proto definition:
float min_score = 5;
// Determine minimal score is excluded or not. By default, it's false;
bool min_excluded = 6;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.yelp.nrtsearch.server.luceneserver.geo;

import org.apache.lucene.util.SloppyMath;

public class GeoUtils {

private static final double KM_TO_M = 1000.0;
Expand Down Expand Up @@ -64,4 +66,12 @@ public static double convertDistanceToADifferentUnit(double distanceNumber, Stri
throw new IllegalArgumentException("Invalid unit " + unit);
}
}

/**
 * Computes the distance between two lat/lon geo points by delegating to Lucene's haversine
 * implementation, {@link SloppyMath#haversinMeters(double, double, double, double)}.
 *
 * @param lat1 latitude of the first point
 * @param lon1 longitude of the first point
 * @param lat2 latitude of the second point
 * @param lon2 longitude of the second point
 * @return distance between the two points, in meters
 */
public static double arcDistance(double lat1, double lon1, double lat2, double lon2) {
  final double distanceInMeters = SloppyMath.haversinMeters(lat1, lon1, lat2, lon2);
  return distanceInMeters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import com.yelp.nrtsearch.server.grpc.MultiFunctionScoreQuery;
import org.apache.lucene.search.Query;

/**
 * Base class for {@link FilterFunction}s that score documents with a decay function.
 *
 * <p>Validates the decay rate shared by all decay functions and maps the requested {@link
 * MultiFunctionScoreQuery.DecayType} to its {@link DecayFunction} implementation.
 */
public abstract class DecayFilterFunction extends FilterFunction {

  /**
   * Constructor.
   *
   * @param filterQuery filter to use when applying this function, or null if none
   * @param weight weight multiple to scale the function score
   * @param decayFunction to score a document with a function that decays depending on the distance
   *     between an origin point and a numeric doc field value
   * @throws IllegalArgumentException if the decay rate is not strictly between 0 and 1, or is NaN
   */
  public DecayFilterFunction(
      Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) {
    super(filterQuery, weight);
    float decay = decayFunction.getDecay();
    // Expressed as a negated in-range check: every comparison with NaN is false, so the
    // "decay <= 0 || decay >= 1" form would silently accept a NaN decay rate.
    if (!(decay > 0 && decay < 1)) {
      throw new IllegalArgumentException("decay rate should be between (0, 1) but is " + decay);
    }
  }

  /**
   * Creates the {@link DecayFunction} implementation for the requested decay type.
   *
   * @param decayType decay type from the query request
   * @return new decay function instance matching the given type
   * @throws IllegalArgumentException if the decay type is not exponential, gaussian or linear
   */
  protected DecayFunction getDecayType(MultiFunctionScoreQuery.DecayType decayType) {
    return switch (decayType) {
      case DECAY_TYPE_GUASSIAN -> new GuassianDecayFunction();
      case DECAY_TYPE_EXPONENTIAL -> new ExponentialDecayFunction();
      case DECAY_TYPE_LINEAR -> new LinearDecayFunction();
      default -> throw new IllegalArgumentException(
          decayType
              + " not supported. Only exponential, guassian and linear decay functions are supported");
    };
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import org.apache.lucene.search.Explanation;

/**
* Strategy for turning a distance from an origin into a decayed score.
*
* <p>Usage is two-step: the user supplied scale and decay rate are first converted into an
* adjusted scale with {@link #computeScale(double, double)}, and that adjusted scale is then
* passed to {@link #computeScore(double, double, double)} and {@link
* #explainComputeScore(double, double, double)}.
*/
public interface DecayFunction {
/**
* Computes the decayed score based on the provided distance, offset, and scale.
*
* @param distance the distance from a given origin point.
* @param offset the point after which the decay starts.
* @param scale adjusted scale factor that influences the rate of decay. This is the value
* returned by computeScale(), not the raw user given scale.
* @return the decayed score after applying the decay function
*/
double computeScore(double distance, double offset, double scale);

/**
* Computes the adjusted scale based on a user given scale and decay rate.
*
* @param scale user given scale.
* @param decay decay rate that decides how the score decreases.
* @return adjusted scale which will be used by the computeScore() and explainComputeScore()
* methods.
*/
double computeScale(double scale, double decay);

/**
* Provides an explanation for the computed score based on the given distance, offset, and scale.
*
* @param distance the distance from a given origin point.
* @param offset the point after which the decay starts.
* @param scale adjusted scale factor that influences the rate of decay. This is the value
* returned by computeScale(), not the raw user given scale.
* @return Explanation object that details the calculations involved in computing the decayed
* score.
*/
Explanation explainComputeScore(double distance, double offset, double scale);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import org.apache.lucene.search.Explanation;

/**
 * Exponential {@link DecayFunction}: the score is {@code exp(scale * max(0, distance - offset))},
 * where {@code scale} is the adjusted scale {@code ln(decay) / userScale} returned by {@link
 * #computeScale(double, double)}.
 */
public class ExponentialDecayFunction implements DecayFunction {
  /**
   * Computes the exponentially decayed score.
   *
   * @param distance the distance from the origin point
   * @param offset distance within which no decay is applied
   * @param scale adjusted scale as returned by {@link #computeScale(double, double)}
   * @return {@code exp(scale * max(0, distance - offset))}
   */
  @Override
  public double computeScore(double distance, double offset, double scale) {
    // Distances inside the offset are clamped to 0, giving exp(0) = 1.
    return Math.exp(scale * Math.max(0.0, distance - offset));
  }

  /**
   * Converts the user given scale and decay rate into the adjusted scale used for scoring.
   *
   * @param scale user given scale
   * @param decay decay rate
   * @return {@code ln(decay) / scale}
   */
  @Override
  public double computeScale(double scale, double decay) {
    return Math.log(decay) / scale;
  }

  /**
   * Explains the score computed for the given inputs.
   *
   * @param distance the distance from the origin point
   * @param offset distance within which no decay is applied
   * @param scale adjusted scale as returned by {@link #computeScale(double, double)}
   * @return matching explanation holding the score and the formula with its values filled in
   */
  @Override
  public Explanation explainComputeScore(double distance, double offset, double scale) {
    // The description mirrors computeScore(); both opened parentheses are closed so the
    // formula shown in explain output is well-formed.
    return Explanation.match(
        (float) computeScore(distance, offset, scale),
        "exp(" + scale + " * max(0.0, " + distance + " - " + offset + "))");
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ public static FilterFunction build(
indexState.docLookup);
return new ScriptFilterFunction(
filterQuery, weight, filterFunctionGrpc.getScript(), scriptSource);
case DECAYFUNCTION:
MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction();
if (decayFunction.hasGeoPoint()) {
return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState);
} else {
throw new IllegalArgumentException(
"Decay Function should contain a geoPoint for Origin field");
}
case FUNCTION_NOT_SET:
return new WeightFilterFunction(filterQuery, weight);
default:
Expand Down
Loading

0 comments on commit 12d836e

Please sign in to comment.