Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.x] Vector rescoring oversamples k instead of num_candidates #119887

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Get rescoring scores - hit ordering may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ setup:
k: 3
num_candidates: 3
"rescore_vector":
"num_candidates_factor": 2.0
"oversample": 2.0

# We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard
# We expect the knn search ops + rescoring k * oversample (for rescoring) per shard
- match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 }

# Search with similarity to check number of operations are propagated correctly
Expand All @@ -131,7 +131,7 @@ setup:
num_candidates: 3
similarity: 100000
"rescore_vector":
"num_candidates_factor": 2.0
"oversample": 2.0

# We expect the knn search ops + rescoring num_cnaidates (for rescoring) per shard
# We expect the knn search ops + rescoring k * oversample (for rescoring) per shard
- match: { profile.shards.0.dfs.knn.0.vector_operations_count: 6 }
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Compare scores as hit IDs may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Get rescoring scores - hit ordering may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ setup:
field: vector
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Get rescoring scores - hit ordering may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Get rescoring scores - hit ordering may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Get rescoring scores - hit ordering may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Compare scores as hit IDs may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Get rescoring scores - hit ordering may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Get rescoring scores - hit ordering may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Compare scores as hit IDs may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Compare scores as hit IDs may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ setup:
k: 3
num_candidates: 3
rescore_vector:
num_candidates_factor: 1.5
oversample: 1.5

# Compare scores as hit IDs may change depending on how things are distributed
- match: { hits.total: 3 }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public static boolean isNotUnitVector(float magnitude) {

public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector
public static final int MAGNITUDE_BYTES = 4;
public static final int NUM_CANDS_OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed for k and num_candidates
public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed

private static DenseVectorFieldMapper toType(FieldMapper in) {
return (DenseVectorFieldMapper) in;
Expand Down Expand Up @@ -2019,7 +2019,7 @@ public Query createKnnQuery(
VectorData queryVector,
int k,
int numCands,
Float numCandsFactor,
Float oversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
Expand All @@ -2035,7 +2035,7 @@ public Query createKnnQuery(
queryVector.asFloatVector(),
k,
numCands,
numCandsFactor,
oversample,
filter,
similarityThreshold,
parentFilter
Expand All @@ -2045,7 +2045,11 @@ public Query createKnnQuery(
}

private boolean needsRescore(Float rescoreOversample) {
return rescoreOversample != null && (indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized());
return rescoreOversample != null && isQuantized();
}

private boolean isQuantized() {
return indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized();
}

private Query createKnnBitQuery(
Expand Down Expand Up @@ -2101,7 +2105,7 @@ private Query createKnnFloatQuery(
float[] queryVector,
int k,
int numCands,
Float numCandsFactor,
Float oversample,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
Expand All @@ -2122,18 +2126,17 @@ && isNotUnitVector(squaredMagnitude)) {
}
}

Integer adjustedK = k;
int adjustedNumCands = numCands;
if (needsRescore(numCandsFactor)) {
// Get all candidates, get top k as part of rescoring
adjustedK = null;
// numCands * numCandsFactor <= NUM_CANDS_OVERSAMPLE_LIMIT. Adjust otherwise.
adjustedNumCands = Math.min((int) Math.ceil(numCands * numCandsFactor), NUM_CANDS_OVERSAMPLE_LIMIT);
int adjustedK = k;
boolean rescore = needsRescore(oversample);
if (rescore) {
// Will get k * oversample for rescoring, and get the top k
adjustedK = Math.min((int) Math.ceil(k * oversample), OVERSAMPLE_LIMIT);
numCands = Math.max(adjustedK, numCands);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, adjustedNumCands, parentFilter)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, adjustedNumCands, filter);
if (needsRescore(numCandsFactor)) {
? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, adjustedK, numCands, parentFilter)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter);
if (rescore) {
knnQuery = new RescoreKnnVectorQuery(
name(),
queryVector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {

DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
String parentPath = context.nestedLookup().getNestedParent(fieldName);
Float numCandidatesFactor = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.numCandidatesFactor();
Float oversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample();

BitSetProducer parentBitSet = null;
if (parentPath != null) {
Expand Down Expand Up @@ -557,15 +557,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
}
}

return vectorFieldType.createKnnQuery(
queryVector,
k,
adjustedNumCands,
numCandidatesFactor,
filterQuery,
vectorSimilarity,
parentBitSet
);
return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, oversample, filterQuery, vectorSimilarity, parentBitSet);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,41 @@

public class RescoreVectorBuilder implements Writeable, ToXContentObject {

public static final ParseField NUM_CANDIDATES_FACTOR_FIELD = new ParseField("num_candidates_factor");
public static final ParseField OVERSAMPLE_FIELD = new ParseField("oversample");
public static final float MIN_OVERSAMPLE = 1.0F;
private static final ConstructingObjectParser<RescoreVectorBuilder, Void> PARSER = new ConstructingObjectParser<>(
"rescore_vector",
args -> new RescoreVectorBuilder((Float) args[0])
);

static {
PARSER.declareFloat(ConstructingObjectParser.constructorArg(), NUM_CANDIDATES_FACTOR_FIELD);
PARSER.declareFloat(ConstructingObjectParser.constructorArg(), OVERSAMPLE_FIELD);
}

// Oversample is required as of now as it is the only field in the rescore vector
private final float numCandidatesFactor;
private final float oversample;

public RescoreVectorBuilder(float numCandidatesFactor) {
Objects.requireNonNull(numCandidatesFactor, "[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be set");
Objects.requireNonNull(numCandidatesFactor, "[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be set");
if (numCandidatesFactor < MIN_OVERSAMPLE) {
throw new IllegalArgumentException("[" + NUM_CANDIDATES_FACTOR_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE);
throw new IllegalArgumentException("[" + OVERSAMPLE_FIELD.getPreferredName() + "] must be >= " + MIN_OVERSAMPLE);
}
this.numCandidatesFactor = numCandidatesFactor;
this.oversample = numCandidatesFactor;
}

public RescoreVectorBuilder(StreamInput in) throws IOException {
this.numCandidatesFactor = in.readFloat();
this.oversample = in.readFloat();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeFloat(numCandidatesFactor);
out.writeFloat(oversample);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(NUM_CANDIDATES_FACTOR_FIELD.getPreferredName(), numCandidatesFactor);
builder.field(OVERSAMPLE_FIELD.getPreferredName(), oversample);
builder.endObject();
return builder;
}
Expand All @@ -71,15 +71,15 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RescoreVectorBuilder that = (RescoreVectorBuilder) o;
return Objects.equals(numCandidatesFactor, that.numCandidatesFactor);
return Objects.equals(oversample, that.oversample);
}

@Override
public int hashCode() {
return Objects.hashCode(numCandidatesFactor);
return Objects.hashCode(oversample);
}

public float numCandidatesFactor() {
return numCandidatesFactor;
public float oversample() {
return oversample;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.BBQ_MIN_DIMS;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.FLOAT;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -456,20 +457,18 @@ public void testRescoreOversampleModifiesNumCandidates() {
);

// Total results is k, internal k is multiplied by oversample
checkRescoreQueryParameters(fieldType, 10, 200, randomInt(), 2.5F, null, 500, 10);
checkRescoreQueryParameters(fieldType, 10, 200, 2.5F, 25, 200, 10);
// If numCands < k, update numCands to k
checkRescoreQueryParameters(fieldType, 10, 20, randomInt(), 2.5F, null, 50, 10);
// Oversampling limits for num candidates
checkRescoreQueryParameters(fieldType, 1000, 1000, randomInt(), 11.0F, null, 10000, 1000);
checkRescoreQueryParameters(fieldType, 5000, 7500, randomInt(), 2.5F, null, 10000, 5000);
checkRescoreQueryParameters(fieldType, 10, 20, 2.5F, 25, 25, 10);
// Oversampling limits for k
checkRescoreQueryParameters(fieldType, 1000, 1000, 11.0F, OVERSAMPLE_LIMIT, OVERSAMPLE_LIMIT, 1000);
}

private static void checkRescoreQueryParameters(
DenseVectorFieldType fieldType,
int k,
int candidates,
int requestSize,
float numCandsFactor,
float oversample,
Integer expectedK,
int expectedCandidates,
int expectedResults
Expand All @@ -478,7 +477,7 @@ private static void checkRescoreQueryParameters(
VectorData.fromFloats(new float[] { 1, 4, 10 }),
k,
candidates,
numCandsFactor,
oversample,
null,
null,
null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.NUM_CANDS_OVERSAMPLE_LIMIT;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -175,8 +176,13 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity()));
query = ((VectorSimilarityQuery) query).getInnerKnnQuery();
}
Integer k = queryBuilder.k();
if (k == null) {
k = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize();
}
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
assertEquals(k.intValue(), (rescoreQuery.k()));
query = rescoreQuery.innerQuery();
}
switch (elementType()) {
Expand All @@ -190,14 +196,11 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
}
BooleanQuery booleanQuery = builder.build();
Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
// The field should always be resolved to the concrete field
Integer k = queryBuilder.k();
Integer numCands = queryBuilder.numCands();
if (queryBuilder.rescoreVectorBuilder() != null && isQuantizedElementType()) {
Float numCandsFactor = queryBuilder.rescoreVectorBuilder().numCandidatesFactor();
int minCands = k == null ? 1 : k;
numCands = Math.max(minCands, (int) Math.ceil(numCands * numCandsFactor));
numCands = Math.min(numCands, NUM_CANDS_OVERSAMPLE_LIMIT);
Float oversample = queryBuilder.rescoreVectorBuilder().oversample();
k = Math.min(OVERSAMPLE_LIMIT, (int) Math.ceil(k * oversample));
numCands = Math.max(numCands, k);
}

Query knnVectorQueryBuilt = switch (elementType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public void testInvalidRescoreVectorBuilder() {
IllegalArgumentException.class,
() -> new KnnSearchBuilder("field", randomVector(3), 10, 100, new RescoreVectorBuilder(0.99F), null)
);
assertThat(e.getMessage(), containsString("[num_candidates_factor] must be >= 1.0"));
assertThat(e.getMessage(), containsString("[oversample] must be >= 1.0"));
}

public void testRewrite() throws Exception {
Expand Down