Skip to content

Commit

Permalink
ANN: account for nullable columns in search and index creation
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 20, 2025
1 parent fc45986 commit 36c552c
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 27 deletions.
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,5 @@ This is not an officially supported Google product.
Limitations
----------

* Approximate Nearest Neighbors (ANN) strategies are only support for the GoogleSQL dialect
* Approximate Nearest Neighbors (ANN) strategies are only supported for the GoogleSQL dialect
* ANN's `ALTER VECTOR INDEX` is not supported by [Google Cloud Spanner](https://cloud.google.com/spanner/docs/find-approximate-nearest-neighbors#limitations)
27 changes: 24 additions & 3 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class SecondaryIndex:
columns: list[str]
storing_columns: Optional[list[str]] = None
num_leaves: Optional[int] = None # Only necessary for ANN
nullable_column: Optional[bool] = False # Only necessary for ANN
num_branches: Optional[int] = None # Only necessary for ANN
tree_depth: Optional[int] = None # Only necessary for ANN
index_type: Optional[DistanceStrategy] = None # Only necessary for ANN
Expand Down Expand Up @@ -551,7 +552,10 @@ def _generate_secondary_indices_ddl_ANN(
secondary_index_ddl_statements = []

for secondary_index in secondary_indexes:
statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({secondary_index.columns[0]})"
column_name = secondary_index.columns[0]
statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})"
if secondary_index.nullable_column:
statement += f"\n\tWHERE {column_name} IS NOT NULL"
options_segments = [f"distance_type='{secondary_index.index_type}'"]
if secondary_index.tree_depth > 0:
tree_depth = secondary_index.tree_depth
Expand Down Expand Up @@ -983,6 +987,7 @@ def search_by_ANN(
limit: int = None,
is_embedding_nullable: bool = False,
where_condition: str = None,
column_is_nullable: bool = False,
) -> List[Any]:
sql = SpannerVectorStore._query_ANN(
column_name,
Expand All @@ -995,6 +1000,7 @@ def search_by_ANN(
limit,
is_embedding_nullable,
where_condition,
column_is_nullable=column_is_nullable,
)
staleness = self._query_parameters.staleness
with self._database.snapshot(
Expand All @@ -1017,6 +1023,7 @@ def _query_ANN(
limit: int = None,
is_embedding_nullable: bool = False,
where_condition: str = None,
column_is_nullable: bool = False,
):
"""
Sample query:
Expand All @@ -1026,6 +1033,16 @@ def _query_ANN(
ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding,
options => JSON '{"num_leaves_to_search": 10}')
LIMIT 100
OR
SELECT DocId
FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}
WHERE NullableDocEmbedding IS NOT NULL
ORDER BY APPROX_EUCLIDEAN_DISTANCE(
ARRAY<FLOAT32>[1.0, 2.0, 3.0], NullableDocEmbedding,
options => JSON '{"num_leaves_to_search": 10}')
LIMIT 100
"""

ann_strategy_name = distance_strategy_to_ANN_function.get(strategy, None)
Expand All @@ -1036,8 +1053,12 @@ def _query_ANN(
f"SELECT {column_name} FROM {table_name}"
+ "@{FORCE_INDEX="
+ f"{index_name}"
+ "}\n"
+ f" ORDER BY {ann_strategy_name}(\n"
+ (
"}\n"
if (not column_is_nullable)
else "}\nWHERE " + f"{embedding_column_name} IS NOT NULL\n"
)
+ f"ORDER BY {ann_strategy_name}(\n"
+ f" ARRAY<FLOAT32>{embedding}, {embedding_column_name}, options => JSON '"
+ '{"num_leaves_to_search": %s})\n' % (num_leaves)
)
Expand Down
78 changes: 55 additions & 23 deletions tests/unit/test_vectore_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,28 +93,38 @@ def test_generate_secondary_indices_ddl_ANN(self):
DistanceStrategy.EUCLIDEIAN,
]

nullables = [True, False]
for distance_strategy in strategies:
got = SpannerVectorStore._generate_secondary_indices_ddl_ANN(
"Documents",
secondary_indexes=[
SecondaryIndex(
index_name="DocEmbeddingIndex",
columns=["DocEmbedding"],
num_branches=1000,
tree_depth=3,
index_type=distance_strategy,
num_leaves=100000,
)
],
)

want = [
"CREATE VECTOR INDEX DocEmbeddingIndex\n"
+ " ON Documents(DocEmbedding)\n"
+ f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)"
]

assert canonicalize(got) == canonicalize(want)
for nullable in nullables:
got = SpannerVectorStore._generate_secondary_indices_ddl_ANN(
"Documents",
secondary_indexes=[
SecondaryIndex(
index_name="DocEmbeddingIndex",
columns=["DocEmbedding"],
nullable_column=nullable,
num_branches=1000,
tree_depth=3,
index_type=distance_strategy,
num_leaves=100000,
)
],
)

want = [
"CREATE VECTOR INDEX DocEmbeddingIndex\n"
+ " ON Documents(DocEmbedding)\n"
+ " WHERE DocEmbedding IS NOT NULL\n"
+ f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)"
]
if not nullable:
want = [
"CREATE VECTOR INDEX DocEmbeddingIndex\n"
+ " ON Documents(DocEmbedding)\n"
+ f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)"
]

assert canonicalize(got) == canonicalize(want)

def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_dialect(
self,
Expand Down Expand Up @@ -206,13 +216,35 @@ def test_query_ANN(self):

want = (
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+ " ORDER BY APPROX_COSINE_DISTANCE(\n"
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
+ ' ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n'
+ "LIMIT 100"
)

print("got", got)
print("want", want)
assert got == want

def test_query_ANN_column_is_nullable(self):
got = SpannerVectorStore._query_ANN(
"DocId",
"Documents",
"DocEmbeddingIndex",
[1.0, 2.0, 3.0],
"DocEmbedding",
10,
DistanceStrategy.COSINE,
limit=100,
column_is_nullable=True,
)

want = (
"SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n"
+ "WHERE DocEmbedding IS NOT NULL\n"
+ "ORDER BY APPROX_COSINE_DISTANCE(\n"
+ ' ARRAY<FLOAT32>[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n'
+ "LIMIT 100"
)

assert got == want


Expand Down

0 comments on commit 36c552c

Please sign in to comment.