Skip to content

Commit

Permalink
Adding queryable encryption range support
Browse files Browse the repository at this point in the history
Supports range style queries for encrypted fields
  • Loading branch information
rozza committed Jan 16, 2025
1 parent 14985a9 commit 671d324
Show file tree
Hide file tree
Showing 15 changed files with 1,148 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Optional;
import java.util.function.Function;

import org.bson.conversions.Bson;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.query.Collation;
import org.springframework.data.mongodb.core.schema.MongoJsonSchema;
Expand Down Expand Up @@ -51,10 +52,11 @@ public class CollectionOptions {
private ValidationOptions validationOptions;
private @Nullable TimeSeriesOptions timeSeriesOptions;
private @Nullable CollectionChangeStreamOptions changeStreamOptions;
private @Nullable Bson encryptedFields;

private CollectionOptions(@Nullable Long size, @Nullable Long maxDocuments, @Nullable Boolean capped,
@Nullable Collation collation, ValidationOptions validationOptions, @Nullable TimeSeriesOptions timeSeriesOptions,
@Nullable CollectionChangeStreamOptions changeStreamOptions) {
@Nullable CollectionChangeStreamOptions changeStreamOptions, @Nullable Bson encryptedFields) {

this.maxDocuments = maxDocuments;
this.size = size;
Expand All @@ -63,6 +65,7 @@ private CollectionOptions(@Nullable Long size, @Nullable Long maxDocuments, @Nul
this.validationOptions = validationOptions;
this.timeSeriesOptions = timeSeriesOptions;
this.changeStreamOptions = changeStreamOptions;
this.encryptedFields = encryptedFields;
}

/**
Expand All @@ -76,7 +79,7 @@ public static CollectionOptions just(Collation collation) {

Assert.notNull(collation, "Collation must not be null");

return new CollectionOptions(null, null, null, collation, ValidationOptions.none(), null, null);
return new CollectionOptions(null, null, null, collation, ValidationOptions.none(), null, null, null);
}

/**
Expand All @@ -86,7 +89,7 @@ public static CollectionOptions just(Collation collation) {
* @since 2.0
*/
public static CollectionOptions empty() {
return new CollectionOptions(null, null, null, null, ValidationOptions.none(), null, null);
return new CollectionOptions(null, null, null, null, ValidationOptions.none(), null, null, null);
}

/**
Expand Down Expand Up @@ -136,7 +139,7 @@ public static CollectionOptions emitChangedRevisions() {
*/
public CollectionOptions capped() {
return new CollectionOptions(size, maxDocuments, true, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
Expand All @@ -148,7 +151,7 @@ public CollectionOptions capped() {
*/
public CollectionOptions maxDocuments(long maxDocuments) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
Expand All @@ -160,7 +163,7 @@ public CollectionOptions maxDocuments(long maxDocuments) {
*/
public CollectionOptions size(long size) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
Expand All @@ -172,7 +175,7 @@ public CollectionOptions size(long size) {
*/
public CollectionOptions collation(@Nullable Collation collation) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
Expand Down Expand Up @@ -293,7 +296,7 @@ public CollectionOptions validation(ValidationOptions validationOptions) {

Assert.notNull(validationOptions, "ValidationOptions must not be null");
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
Expand All @@ -307,7 +310,7 @@ public CollectionOptions timeSeries(TimeSeriesOptions timeSeriesOptions) {

Assert.notNull(timeSeriesOptions, "TimeSeriesOptions must not be null");
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
Expand All @@ -321,7 +324,19 @@ public CollectionOptions changeStream(CollectionChangeStreamOptions changeStream

Assert.notNull(changeStreamOptions, "ChangeStreamOptions must not be null");
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
* Create new {@link CollectionOptions} with the given {@code encryptedFields}.
*
* @param encryptedFields can be null
* @return new instance of {@link CollectionOptions}.
* @since QERange
*/
public CollectionOptions encryptedFields(@Nullable Bson encryptedFields) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions, encryptedFields);
}

/**
Expand Down Expand Up @@ -392,12 +407,22 @@ public Optional<CollectionChangeStreamOptions> getChangeStreamOptions() {
return Optional.ofNullable(changeStreamOptions);
}

/**
* Get the {@code encryptedFields} if available.
*
* @return {@link Optional#empty()} if not specified.
* @since QERange
*/
public Optional<Bson> getEncryptedFields() {
return Optional.ofNullable(encryptedFields);
}

@Override
public String toString() {
return "CollectionOptions{" + "maxDocuments=" + maxDocuments + ", size=" + size + ", capped=" + capped
+ ", collation=" + collation + ", validationOptions=" + validationOptions + ", timeSeriesOptions="
+ timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", disableValidation="
+ disableValidation() + ", strictValidation=" + strictValidation() + ", moderateValidation="
+ timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", encryptedFields=" + encryptedFields
+ ", disableValidation=" + disableValidation() + ", strictValidation=" + strictValidation() + ", moderateValidation="
+ moderateValidation() + ", warnOnValidationError=" + warnOnValidationError() + ", failOnValidationError="
+ failOnValidationError() + '}';
}
Expand Down Expand Up @@ -431,7 +456,10 @@ public boolean equals(@Nullable Object o) {
if (!ObjectUtils.nullSafeEquals(timeSeriesOptions, that.timeSeriesOptions)) {
return false;
}
return ObjectUtils.nullSafeEquals(changeStreamOptions, that.changeStreamOptions);
if (!ObjectUtils.nullSafeEquals(changeStreamOptions, that.changeStreamOptions)) {
return false;
}
return ObjectUtils.nullSafeEquals(encryptedFields, that.encryptedFields);
}

@Override
Expand All @@ -443,6 +471,7 @@ public int hashCode() {
result = 31 * result + ObjectUtils.nullSafeHashCode(validationOptions);
result = 31 * result + ObjectUtils.nullSafeHashCode(timeSeriesOptions);
result = 31 * result + ObjectUtils.nullSafeHashCode(changeStreamOptions);
result = 31 * result + ObjectUtils.nullSafeHashCode(encryptedFields);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ public final class EncryptionAlgorithms {
public static final String AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic";
public static final String AEAD_AES_256_CBC_HMAC_SHA_512_Random = "AEAD_AES_256_CBC_HMAC_SHA_512-Random";

public static final String RANGE = "Range";

}
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ public CreateCollectionOptions convertToCreateCollectionOptions(@Nullable Collec
collectionOptions.getChangeStreamOptions().ifPresent(it -> result
.changeStreamPreAndPostImagesOptions(new ChangeStreamPreAndPostImagesOptions(it.getPreAndPostImages())));

collectionOptions.getEncryptedFields().ifPresent(result::encryptedFields);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2172,8 +2172,9 @@ protected <O> AggregationResults<O> doAggregate(Aggregation aggregation, String

List<Document> pipeline = aggregationUtil.createPipeline(aggregation, context);

if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
// TODO revert to DEBUG
if (LOGGER.isErrorEnabled()) {
LOGGER.error(
String.format("Executing aggregation: %s in collection %s", serializeToJsonSafely(pipeline), collectionName));
}

Expand Down Expand Up @@ -2594,10 +2595,10 @@ protected <S, T> List<T> doFind(String collectionName,
Document mappedFields = queryContext.getMappedFields(entity, EntityProjection.nonProjecting(entityClass));
Document mappedQuery = queryContext.getMappedQuery(entity);

if (LOGGER.isDebugEnabled()) {

// TODO revert to DEBUG
if (LOGGER.isErrorEnabled()) {
Document mappedSort = getMappedSortObject(query, entityClass);
LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
LOGGER.error(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), entityClass,
collectionName));
}
Expand All @@ -2623,8 +2624,9 @@ <S, T> List<T> doFind(CollectionPreparer<MongoCollection<Document>> collectionPr
Document mappedQuery = queryContext.getMappedQuery(entity);
Document mappedSort = getMappedSortObject(query, sourceClass);

if (LOGGER.isDebugEnabled()) {
LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
// TODO revert to DEBUG
if (LOGGER.isErrorEnabled()) {
LOGGER.error(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), sourceClass,
collectionName));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.springframework.data.mapping.model.PropertyValueProvider;
import org.springframework.data.mapping.model.SpELContext;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;

import org.springframework.data.util.TypeInformation;
import org.springframework.lang.Nullable;

Expand All @@ -33,24 +34,39 @@
public class MongoConversionContext implements ValueConversionContext<MongoPersistentProperty> {

private final PropertyValueProvider<MongoPersistentProperty> accessor; // TODO: generics
private final @Nullable MongoPersistentProperty persistentProperty;
private final MongoConverter mongoConverter;

@Nullable private final MongoPersistentProperty persistentProperty;
@Nullable private final SpELContext spELContext;
@Nullable private final String queryFieldPath;

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor,
@Nullable MongoPersistentProperty persistentProperty, MongoConverter mongoConverter) {
this(accessor, persistentProperty, mongoConverter, null);
this(accessor, mongoConverter, persistentProperty, null);
}

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor,
@Nullable MongoPersistentProperty persistentProperty, MongoConverter mongoConverter,
@Nullable SpELContext spELContext) {
this(accessor, mongoConverter, persistentProperty, spELContext, null);
}

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor, MongoConverter mongoConverter,
@Nullable MongoPersistentProperty persistentProperty, @Nullable String queryFieldPath) {
this(accessor, mongoConverter, persistentProperty, null, queryFieldPath);
}

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor,
MongoConverter mongoConverter,
@Nullable MongoPersistentProperty persistentProperty,
@Nullable SpELContext spELContext,
@Nullable String queryFieldPath) {

this.accessor = accessor;
this.persistentProperty = persistentProperty;
this.mongoConverter = mongoConverter;
this.spELContext = spELContext;
this.queryFieldPath = queryFieldPath;
}

@Override
Expand Down Expand Up @@ -84,4 +100,9 @@ public <T> T read(@Nullable Object value, TypeInformation<T> target) {
public SpELContext getSpELContext() {
return spELContext;
}

@Nullable
public String getQueryFieldPath() {
return queryFieldPath;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter.NestedDocument;
import org.springframework.data.mongodb.core.mapping.FieldName;
import org.springframework.data.mongodb.core.mapping.MongoField;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty.PropertyToFieldNameConverter;
Expand Down Expand Up @@ -356,9 +357,10 @@ protected Entry<String, Object> getMappedObjectForField(Field field, Object rawV
return createMapEntry(key, getMappedObject(mongoExpression.toDocument(), field.getEntity()));
}

if (isNestedKeyword(rawValue) && !field.isIdField()) {
if (isNestedKeyword(rawValue)) {
Keyword keyword = new Keyword((Document) rawValue);
value = getMappedKeyword(field, keyword);
field = field.with(keyword.getKey());
value = field.isIdField() ? getMappedValue(field, rawValue) : getMappedKeyword(field, keyword);
} else {
value = getMappedValue(field, rawValue);
}
Expand Down Expand Up @@ -455,10 +457,19 @@ protected Document getMappedKeyword(Field property, Keyword keyword) {
@Nullable
@SuppressWarnings("unchecked")
protected Object getMappedValue(Field documentField, Object sourceValue) {

Object value = applyFieldTargetTypeHintToValue(documentField, sourceValue);

if (documentField.getProperty() != null
MongoPersistentProperty property = documentField.getProperty();

String queryPath = property != null && !property.getFieldName().equals(documentField.name) ?
property.getFieldName() + "." + documentField.name : documentField.name;

// TODO add flattened path to convert value and remove logging
if (LOGGER.isErrorEnabled()) {
LOGGER.error(" >-|-> " + queryPath);
}

if (property != null
&& converter.getCustomConversions().hasValueConverter(documentField.getProperty())) {

PropertyValueConverter<Object, Object, ValueConversionContext<MongoPersistentProperty>> valueConverter = converter
Expand Down Expand Up @@ -668,8 +679,17 @@ private Object convertValue(Field documentField, Object sourceValue, Object valu
PropertyValueConverter<Object, Object, ValueConversionContext<MongoPersistentProperty>> valueConverter) {

MongoPersistentProperty property = documentField.getProperty();

String queryPath = property != null && !property.getFieldName().equals(documentField.name) ?
property.getFieldName() + "." + documentField.name : documentField.name;

// TODO add flattened path to convert value and remove logging
if (LOGGER.isErrorEnabled()) {
LOGGER.error(" >--> " + queryPath);
}

MongoConversionContext conversionContext = new MongoConversionContext(NoPropertyPropertyValueProvider.INSTANCE,
property, converter);
converter, property, queryPath);

/* might be an $in clause with multiple entries */
if (property != null && !property.isCollectionLike() && sourceValue instanceof Collection<?> collection) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,10 @@ public <T> T read(@Nullable Object value, TypeInformation<T> target) {
public <T> T write(@Nullable Object value, TypeInformation<T> target) {
return conversionContext.write(value, target);
}

// TODO QE - add to interface
@Nullable
public String getQueryFieldPath() {
return conversionContext.getQueryFieldPath();
}
}
Loading

0 comments on commit 671d324

Please sign in to comment.