Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Vector Search #4882

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.5.0-SNAPSHOT</version>
<version>4.5.x-GH-4706-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data MongoDB</name>
Expand All @@ -26,8 +26,8 @@
<properties>
<project.type>multi</project.type>
<dist.id>spring-data-mongodb</dist.id>
<springdata.commons>3.5.0-SNAPSHOT</springdata.commons>
<mongo>5.3.0-beta0</mongo>
<springdata.commons>3.5.0-GH-3193-SNAPSHOT</springdata.commons>
<mongo>5.3.0</mongo>
<mongodb-crypt>${mongo}</mongodb-crypt>
<mongo.reactivestreams>${mongo}</mongo.reactivestreams>
<jmh.version>1.19</jmh.version>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.5.0-SNAPSHOT</version>
<version>4.5.x-GH-4706-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
9 changes: 8 additions & 1 deletion spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.5.0-SNAPSHOT</version>
<version>4.5.x-GH-4706-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down Expand Up @@ -131,6 +131,13 @@
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<version>4.2.2</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.reactivex.rxjava3</groupId>
<artifactId>rxjava</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.List;

import org.bson.Document;

import org.springframework.dao.DataAccessException;
import org.springframework.data.mongodb.MongoDatabaseFactory;
import org.springframework.data.mongodb.UncategorizedMongoDbException;
Expand Down Expand Up @@ -51,11 +50,11 @@ public class DefaultIndexOperations implements IndexOperations {

private static final String PARTIAL_FILTER_EXPRESSION_KEY = "partialFilterExpression";

private final String collectionName;
private final QueryMapper mapper;
private final @Nullable Class<?> type;
protected final String collectionName;
protected final QueryMapper mapper;
protected final @Nullable Class<?> type;

private final MongoOperations mongoOperations;
protected final MongoOperations mongoOperations;

/**
* Creates a new {@link DefaultIndexOperations}.
Expand Down Expand Up @@ -133,7 +132,7 @@ public String ensureIndex(IndexDefinition indexDefinition) {
}

@Nullable
private MongoPersistentEntity<?> lookupPersistentEntity(@Nullable Class<?> entityType, String collection) {
protected MongoPersistentEntity<?> lookupPersistentEntity(@Nullable Class<?> entityType, String collection) {

if (entityType != null) {
return mapper.getMappingContext().getRequiredPersistentEntity(entityType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@
import org.springframework.data.mongodb.core.convert.MongoWriter;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.index.DefaultSearchIndexOperations;
import org.springframework.data.mongodb.core.index.IndexOperations;
import org.springframework.data.mongodb.core.index.IndexOperationsProvider;
import org.springframework.data.mongodb.core.index.MongoMappingEventPublisher;
import org.springframework.data.mongodb.core.index.MongoPersistentEntityIndexCreator;
import org.springframework.data.mongodb.core.index.SearchIndexOperations;
import org.springframework.data.mongodb.core.index.SearchIndexOperationsProvider;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
Expand Down Expand Up @@ -182,8 +185,8 @@
* @author Michael Krog
* @author Jakub Zurawa
*/
public class MongoTemplate
implements MongoOperations, ApplicationContextAware, IndexOperationsProvider, ReadPreferenceAware {
public class MongoTemplate implements MongoOperations, ApplicationContextAware, IndexOperationsProvider,
SearchIndexOperationsProvider, ReadPreferenceAware {

private static final Log LOGGER = LogFactory.getLog(MongoTemplate.class);
private static final WriteResultChecking DEFAULT_WRITE_RESULT_CHECKING = WriteResultChecking.NONE;
Expand Down Expand Up @@ -768,6 +771,21 @@ public IndexOperations indexOps(Class<?> entityClass) {
return indexOps(getCollectionName(entityClass), entityClass);
}

@Override
public SearchIndexOperations searchIndexOps(String collectionName) {
return searchIndexOps(null, collectionName);
}

@Override
public SearchIndexOperations searchIndexOps(Class<?> type) {
return new DefaultSearchIndexOperations(this, type);
}

@Override
public SearchIndexOperations searchIndexOps(@Nullable Class<?> type, String collectionName) {
return new DefaultSearchIndexOperations(this, collectionName, type);
}

@Override
public BulkOperations bulkOps(BulkMode mode, String collectionName) {
return bulkOps(mode, null, collectionName);
Expand Down Expand Up @@ -1313,7 +1331,7 @@ private WriteConcern potentiallyForceAcknowledgedWrite(@Nullable WriteConcern wc

if (ObjectUtils.nullSafeEquals(WriteResultChecking.EXCEPTION, writeResultChecking)) {
if (wc == null || wc.getWObject() == null
|| (wc.getWObject()instanceof Number concern && concern.intValue() < 1)) {
|| (wc.getWObject() instanceof Number concern && concern.intValue() < 1)) {
return WriteConcern.ACKNOWLEDGED;
}
}
Expand Down Expand Up @@ -1965,7 +1983,8 @@ public <T> List<T> mapReduce(Query query, Class<?> domainType, String inputColle
}

if (mapReduceOptions.getOutputSharded().isPresent()) {
MongoCompatibilityAdapter.mapReduceIterableAdapter(mapReduce).sharded(mapReduceOptions.getOutputSharded().get());
MongoCompatibilityAdapter.mapReduceIterableAdapter(mapReduce)
.sharded(mapReduceOptions.getOutputSharded().get());
}

if (StringUtils.hasText(mapReduceOptions.getOutputCollection()) && !mapReduceOptions.usesInlineOutput()) {
Expand Down Expand Up @@ -2064,7 +2083,7 @@ public <T> List<T> findAllAndRemove(Query query, Class<T> entityClass, String co
}

@Override
public <T> UpdateResult replace(Query query, T replacement, ReplaceOptions options, String collectionName){
public <T> UpdateResult replace(Query query, T replacement, ReplaceOptions options, String collectionName) {

Assert.notNull(replacement, "Replacement must not be null");
return replace(query, (Class<T>) ClassUtils.getUserClass(replacement), replacement, options, collectionName);
Expand Down Expand Up @@ -2740,8 +2759,7 @@ protected <T> T doFindAndModify(CollectionPreparer collectionPreparer, String co
LOGGER.debug(String.format(
"findAndModify using query: %s fields: %s sort: %s for class: %s and update: %s in collection: %s",
serializeToJsonSafely(mappedQuery), fields, serializeToJsonSafely(sort), entityClass,
serializeToJsonSafely(mappedUpdate),
collectionName));
serializeToJsonSafely(mappedUpdate), collectionName));
}

return executeFindOneInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ public static UnwindOperation unwind(String field, String arrayIndex) {
}

/**
* Factory method to create a new {@link UnwindOperation} for the field with the given name, including the name of a new
* field to hold the array index of the element as {@code arrayIndex} using {@code preserveNullAndEmptyArrays}. Note
* that extended unwind is supported in MongoDB version 3.2+.
* Factory method to create a new {@link UnwindOperation} for the field with the given name, including the name of a
* new field to hold the array index of the element as {@code arrayIndex} using {@code preserveNullAndEmptyArrays}.
* Note that extended unwind is supported in MongoDB version 3.2+.
*
* @param field must not be {@literal null} or empty.
* @param arrayIndex must not be {@literal null} or empty.
Expand Down Expand Up @@ -428,6 +428,20 @@ public static StartWithBuilder graphLookup(String fromCollection) {
return GraphLookupOperation.builder().from(fromCollection);
}

/**
* Creates a new {@link VectorSearchOperation} by starting from the {@code indexName} to use.
*
* @param indexName must not be {@literal null} or empty.
* @return new instance of {@link VectorSearchOperation.PathContributor}.
* @since 4.5
*/
public static VectorSearchOperation.PathContributor vectorSearch(String indexName) {

Assert.hasText(indexName, "Index name must not be null or empty");

return VectorSearchOperation.search(indexName);
}

/**
* Factory method to create a new {@link SortOperation} for the given {@link Sort}.
*
Expand Down Expand Up @@ -669,14 +683,14 @@ public static LookupOperation lookup(Field from, Field localField, Field foreign

/**
* Entrypoint for creating {@link LookupOperation $lookup} using a fluent builder API.
*
* <pre class="code">
* Aggregation.lookup().from("restaurants")
* .localField("restaurant_name")
* .foreignField("name")
* .let(newVariable("orders_drink").forField("drink"))
* .pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages")))))
* .as("matches")
* Aggregation.lookup().from("restaurants").localField("restaurant_name").foreignField("name")
* .let(newVariable("orders_drink").forField("drink"))
* .pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages")))))
* .as("matches")
* </pre>
*
* @return new instance of {@link LookupOperationBuilder}.
* @since 4.1
*/
Expand Down
Loading