From 4e0e2b65b2eef8fbd1b77158b05714f573f2fd0d Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 26 Dec 2024 02:56:46 -0800 Subject: [PATCH 01/13] fix(testing+linting): add nox lint+format directives This change introduces new nox directives: * blacken: `nox -s blacken` * format: `nox -s format` to apply formatting to files * lint: `nox -s lint` to flag linting issues * unit: to run unit tests locally which are the basis to enable scalable development and continuous testing as I prepare to bring in Approximate Nearest Neighors (ANN) functionality into this package. Also while here, fixed a typo in the README.rst file that didn't have the correct import path. --- README.rst | 2 +- noxfile.py | 87 +++++++++++++++++-- src/langchain_google_spanner/graph_qa.py | 1 - src/langchain_google_spanner/graph_store.py | 20 ++--- src/langchain_google_spanner/loader.py | 2 +- .../test_spanner_chat_message_history.py | 2 +- tests/integration/test_spanner_graph_qa.py | 3 +- tests/integration/test_spanner_loader.py | 2 +- .../integration/test_spanner_vector_store.py | 2 +- 9 files changed, 98 insertions(+), 23 deletions(-) diff --git a/README.rst b/README.rst index cb047dd..1c1aba2 100644 --- a/README.rst +++ b/README.rst @@ -73,7 +73,7 @@ Use a vector store to store embedded data and perform vector search. .. code-block:: python - from langchain_google_sapnner import SpannerVectorstore + from langchain_google_spanner import SpannerVectorstore from langchain.embeddings import VertexAIEmbeddings embeddings_service = VertexAIEmbeddings(model_name="textembedding-gecko@003") diff --git a/noxfile.py b/noxfile.py index 2ad8aec..e7d5e3c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -18,25 +18,34 @@ import os import pathlib -import shutil from pathlib import Path -from typing import Optional +import shutil +from typing import List, Optional import nox DEFAULT_PYTHON_VERSION = "3.10" CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() +FLAKE8_VERSION = "flake8==6.1.0" +BLACK_VERSION = "black[jupyter]==23.7.0" +ISORT_VERSION = "isort==5.11.0" +LINT_PATHS = ["src", "tests", "noxfile.py"] + nox.options.sessions = [ - "docs", + "blacken", "docfx", + "docs", + "format", + "lint", + "unit", ] # Error if a python version is missing nox.options.error_on_missing_interpreters = True -@nox.session(python="3.10") +@nox.session(python=DEFAULT_PYTHON_VERSION) def docs(session): """Build the docs for this library.""" @@ -71,7 +80,7 @@ def docs(session): ) -@nox.session(python="3.10") +@nox.session(python=DEFAULT_PYTHON_VERSION) def docfx(session): """Build the docfx yaml files for this library.""" @@ -115,3 +124,71 @@ def docfx(session): os.path.join("docs", ""), os.path.join("docs", "_build", "html", ""), ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint(session): + """Run linters. + + Returns a failure if the linters find linting errors or + sufficiently serious code quality issues. + """ + session.install(FLAKE8_VERSION, BLACK_VERSION) + session.run( + "black", + "--check", + *LINT_PATHS, + ) + session.run("flake8", "google", "tests") + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint_setup_py(session): + """Verify that setup.py is valid (including an RST check).""" + session.install("docutils", "pygments") + session.run("python", "setup.py", "check", "--restructuredtext", "--strict") + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def blacken(session): + session.install(BLACK_VERSION) + session.run( + "black", + *LINT_PATHS, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def format(session): + session.install(BLACK_VERSION, ISORT_VERSION) + # Sort imports in strict alphabetical order. + session.run( + "isort", + "--fss", + *LINT_PATHS, + ) + session.run( + "black", + *LINT_PATHS, + ) + + +def unit(session): + install_unittest_dependencies(session) + session.run( + "py.test", + "--quiet", + os.path.join("tests", "unit"), + ) + + +UNIT_TEST_STANDARD_DEPENDENCIES = [ + "mock", + "pytest", +] +UNIT_TEST_DEPENDENCIES: List[str] = [] + + +def install_unittest_dependencies(session, *constraints): + standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES + session.install(*standard_deps, *constraints) diff --git a/src/langchain_google_spanner/graph_qa.py b/src/langchain_google_spanner/graph_qa.py index ff399b4..4d72c5b 100644 --- a/src/langchain_google_spanner/graph_qa.py +++ b/src/langchain_google_spanner/graph_qa.py @@ -332,7 +332,6 @@ def _call( inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, str]: - intermediate_steps: List = [] """Generate gql statement, uses it to look up in db and answer question.""" diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index d3e03a0..e6e211a 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -14,9 +14,9 @@ from __future__ import annotations +from abc import abstractmethod import re import string -from abc import abstractmethod from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union from google.cloud import spanner @@ -211,9 +211,9 @@ def from_nodes(name: str, nodes: List[Node]) -> ElementSchema: for k, v in n.properties.items() } ) - node.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( - TypeUtility.value_to_param_type(nodes[0].id) - ) + node.types[ + ElementSchema.NODE_KEY_COLUMN_NAME + ] = TypeUtility.value_to_param_type(nodes[0].id) return node @staticmethod @@ -264,12 +264,12 @@ def from_edges(name: str, edges: List[Relationship]) -> ElementSchema: for k, v in e.properties.items() } ) - edge.types[ElementSchema.NODE_KEY_COLUMN_NAME] = ( - TypeUtility.value_to_param_type(edges[0].source.id) - ) - edge.types[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = ( - TypeUtility.value_to_param_type(edges[0].target.id) - ) + edge.types[ + ElementSchema.NODE_KEY_COLUMN_NAME + ] = TypeUtility.value_to_param_type(edges[0].source.id) + edge.types[ + ElementSchema.TARGET_NODE_KEY_COLUMN_NAME + ] = TypeUtility.value_to_param_type(edges[0].target.id) edge.source = NodeReference( edges[0].source.type, diff --git a/src/langchain_google_spanner/loader.py b/src/langchain_google_spanner/loader.py index d71de11..a1336d2 100644 --- a/src/langchain_google_spanner/loader.py +++ b/src/langchain_google_spanner/loader.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import datetime import json -from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Union from google.cloud.spanner import Client, KeySet # type: ignore diff --git a/tests/integration/test_spanner_chat_message_history.py b/tests/integration/test_spanner_chat_message_history.py index 397af03..b83d6fb 100644 --- a/tests/integration/test_spanner_chat_message_history.py +++ b/tests/integration/test_spanner_chat_message_history.py @@ -16,10 +16,10 @@ import os import uuid -import pytest # noqa from google.cloud.spanner import Client # type: ignore from langchain_core.messages.ai import AIMessage from langchain_core.messages.human import HumanMessage +import pytest # noqa from langchain_google_spanner import SpannerChatMessageHistory diff --git a/tests/integration/test_spanner_graph_qa.py b/tests/integration/test_spanner_graph_qa.py index 8bac7b8..55e8153 100644 --- a/tests/integration/test_spanner_graph_qa.py +++ b/tests/integration/test_spanner_graph_qa.py @@ -16,12 +16,12 @@ import random import string -import pytest from google.cloud import spanner from langchain.evaluation import load_evaluator from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship from langchain_core.documents import Document from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +import pytest from langchain_google_spanner.graph_qa import SpannerGraphQAChain from langchain_google_spanner.graph_store import SpannerGraphStore @@ -145,7 +145,6 @@ def load_data(graph: SpannerGraphStore): class TestSpannerGraphQAChain: - @pytest.fixture(scope="module") def setup_db_load_data(self): graph = get_spanner_graph() diff --git a/tests/integration/test_spanner_loader.py b/tests/integration/test_spanner_loader.py index 85f483f..51a8c2f 100644 --- a/tests/integration/test_spanner_loader.py +++ b/tests/integration/test_spanner_loader.py @@ -15,9 +15,9 @@ import os import uuid -import pytest from google.cloud.spanner import Client, KeySet # type: ignore from langchain_core.documents import Document +import pytest from langchain_google_spanner.loader import Column, SpannerDocumentSaver, SpannerLoader diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 9a19f7b..d3f6cfc 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -16,10 +16,10 @@ import os import uuid -import pytest from google.cloud.spanner import Client # type: ignore from langchain_community.document_loaders import HNLoader from langchain_community.embeddings import FakeEmbeddings +import pytest from langchain_google_spanner.vector_store import ( # type: ignore DistanceStrategy, From 74434ad1c8d3facd3297c2df07f79e6fd4094ac9 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 26 Dec 2024 04:34:34 -0800 Subject: [PATCH 02/13] feat: add Approximate Nearest Neighbor support to distance strategies This change adds ANN distance strategies for GoogleSQL semantics. While here started unit tests to effectively test out components without having to have a running Cloud Spanner instance. Updates #94 --- README.rst | 5 ++ noxfile.py | 11 +++ src/langchain_google_spanner/vector_store.py | 42 +++++++++--- tests/unit/test_vectore_store.py | 71 ++++++++++++++++++++ 4 files changed, 119 insertions(+), 10 deletions(-) create mode 100644 tests/unit/test_vectore_store.py diff --git a/README.rst b/README.rst index 1c1aba2..f09c809 100644 --- a/README.rst +++ b/README.rst @@ -206,3 +206,8 @@ Disclaimer This is not an officially supported Google product. + +Limitations +---------- + +* Approximate Nearest Neighbors (ANN) strategies are only support for the GoogleSQL dialect diff --git a/noxfile.py b/noxfile.py index e7d5e3c..7eaecb3 100644 --- a/noxfile.py +++ b/noxfile.py @@ -173,6 +173,7 @@ def format(session): ) +@nox.session(python=DEFAULT_PYTHON_VERSION) def unit(session): install_unittest_dependencies(session) session.run( @@ -192,3 +193,13 @@ def unit(session): def install_unittest_dependencies(session, *constraints): standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES session.install(*standard_deps, *constraints) + session.run( + "pip", + "install", + "--no-compile", # To ensure no byte recompliation which is usually super slow + "-q", + "--disable-pip-version-check", # Avoid the slow version check + ".", + "-r", + "requirements.txt", + ) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index b5084e8..75c2897 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -14,13 +14,12 @@ from __future__ import annotations -import datetime -import logging from abc import ABC, abstractmethod +import datetime from enum import Enum +import logging from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union -import numpy as np from google.cloud import spanner # type: ignore from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1 import JsonObject, param_types @@ -28,6 +27,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore +import numpy as np from .version import __version__ @@ -104,6 +104,10 @@ class DistanceStrategy(Enum): COSINE = 1 EUCLIDEIAN = 2 + DOT_PRODUCT = 3 + APPROX_DOT_PRODUCT = 4 + APPROX_COSINE = 5 + APPROX_EUCLIDEAN = 6 class DialectSemantics(ABC): @@ -139,16 +143,23 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: ) +_GOOGLE_DISTANCE_ALGO_NAMES = { + DistanceStrategy.APPROX_COSINE: "APPROX_COSINE_DISTANCE", + DistanceStrategy.APPROX_DOT_PRODUCT: "APPROX_DOT_PRODUCT", + DistanceStrategy.APPROX_EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE", + DistanceStrategy.COSINE: "COSINE_DISTANCE", + DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", + DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE", +} + + class GoogleSqlSemnatics(DialectSemantics): """ Implementation of dialect semantics for Google SQL. """ def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - if distance_strategy == DistanceStrategy.COSINE: - return "COSINE_DISTANCE" - - return "EUCLIDEAN_DISTANCE" + return _GOOGLE_DISTANCE_ALGO_NAMES.get(distance_strategy, "EUCLIDEAN") def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: where_clause_condition = " AND ".join( @@ -163,15 +174,25 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: return dict(zip(columns, values)) +_PG_DISTANCE_ALGO_NAMES = { + DistanceStrategy.COSINE: "spanner.cosine_distance", + DistanceStrategy.DOT_PRODUCT: "spanner.dot_product", + DistanceStrategy.EUCLIDEIAN: "spanner.euclidean_distance", +} + + class PGSqlSemnatics(DialectSemantics): """ Implementation of dialect semantics for PostgreSQL. """ def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - if distance_strategy == DistanceStrategy.COSINE: - return "spanner.cosine_distance" - return "spanner.euclidean_distance" + name = _PG_DISTANCE_ALGO_NAMES.get(distance_strategy, None) + if name is None: + raise Exception( + "Unsupported PostgreSQL distance strategy: {}".format(distance_strategy) + ) + return name def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: where_clause_condition = " AND ".join( @@ -210,6 +231,7 @@ class NearestNeighborsAlgorithm(Enum): """ EXACT_NEAREST_NEIGHBOR = 1 + APPROXIMATE_NEAREST_NEIGHBOR = 2 def __init__( self, diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py new file mode 100644 index 0000000..31f9e3d --- /dev/null +++ b/tests/unit/test_vectore_store.py @@ -0,0 +1,71 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import unittest + +from langchain_google_spanner.vector_store import ( + DistanceStrategy, + GoogleSqlSemnatics, + PGSqlSemnatics, +) + + +class TestGoogleSqlSemnatics(unittest.TestCase): + def test_distance_function_to_string(self): + cases = [ + (DistanceStrategy.COSINE, "COSINE_DISTANCE"), + (DistanceStrategy.DOT_PRODUCT, "DOT_PRODUCT"), + (DistanceStrategy.EUCLIDEIAN, "EUCLIDEAN_DISTANCE"), + (DistanceStrategy.APPROX_COSINE, "APPROX_COSINE_DISTANCE"), + (DistanceStrategy.APPROX_DOT_PRODUCT, "APPROX_DOT_PRODUCT"), + (DistanceStrategy.APPROX_EUCLIDEAN, "APPROX_EUCLIDEAN_DISTANCE"), + ] + + sem = GoogleSqlSemnatics() + got_results = [] + want_results = [] + for strategy, want_str in cases: + got_results.append(sem.getDistanceFunction(strategy)) + want_results.append(want_str) + + assert got_results == want_results + + +class TestPGSqlSemnatics(unittest.TestCase): + def test_distance_function_to_string(self): + cases = [ + (DistanceStrategy.COSINE, "spanner.cosine_distance"), + (DistanceStrategy.DOT_PRODUCT, "spanner.dot_product"), + (DistanceStrategy.EUCLIDEIAN, "spanner.euclidean_distance"), + ] + + sem = PGSqlSemnatics() + got_results = [] + want_results = [] + for strategy, want_str in cases: + got_results.append(sem.getDistanceFunction(strategy)) + want_results.append(want_str) + + assert got_results == want_results + + def test_distance_function_raises_exception_if_unknown(self): + strategies = [ + DistanceStrategy.APPROX_COSINE, + DistanceStrategy.APPROX_DOT_PRODUCT, + DistanceStrategy.APPROX_EUCLIDEAN, + ] + + for strategy in strategies: + with self.assertRaises(Exception): + sem.getDistanceFunction(strategy) From 6a753025ddb870b1c80875a4e4bdec07793ead25 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 26 Dec 2024 06:57:45 -0800 Subject: [PATCH 03/13] Wire up ANN vector index creation --- src/langchain_google_spanner/vector_store.py | 142 ++++++++++++++++--- tests/unit/test_vectore_store.py | 36 +++++ 2 files changed, 161 insertions(+), 17 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 75c2897..080f1a6 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -87,6 +87,10 @@ class SecondaryIndex: index_name: str columns: list[str] storing_columns: Optional[list[str]] = None + num_leaves: Optional[int] = None # Only necessary for ANN + num_branches: Optional[int] = None # Only necessary for ANN + tree_depth: Optional[int] = None # Only necessary for ANN + index_type: Optional[DistanceStrategy] = None # Only necessary for ANN def __post_init__(self): # Check if column_name is None after initialization @@ -109,6 +113,16 @@ class DistanceStrategy(Enum): APPROX_COSINE = 5 APPROX_EUCLIDEAN = 6 + def __str__(self): + return DISTANCE_STRATEGY_STRING[self] + + +DISTANCE_STRATEGY_STRING = { + DistanceStrategy.COSINE: "COSINE", + DistanceStrategy.EUCLIDEIAN: "EUCLIDEIAN", + DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", +} + class DialectSemantics(ABC): """ @@ -152,6 +166,12 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE", } +_GOOGLE_ALGO_INDEX_NAME = { + DistanceStrategy.COSINE: "COSINE", + DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", + DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN", +} + class GoogleSqlSemnatics(DialectSemantics): """ @@ -173,6 +193,12 @@ def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: return dict(zip(columns, values)) + def getIndexDistanceType(self, distance_strategy) -> str: + value = _GOOGLE_ALGO_INDEX_NAME.get(distance_strategy, None) + if value is None: + raise Exception(f"{distance_strategy} is unsupported for distance_type") + return value + _PG_DISTANCE_ALGO_NAMES = { DistanceStrategy.COSINE: "spanner.cosine_distance", @@ -276,6 +302,15 @@ def __init__( self.staleness = {key: value} +DEFAULT_ANN_TREE_DEPTH = 2 +ANN_ACCEPTABLE_TREE_DEPTHS = (2, 3) + + +class AlgoKind(Enum): + KNN = 0 + ANN = 1 + + class SpannerVectorStore(VectorStore): GSQL_TYPES = { CONTENT_COLUMN_NAME: ["STRING"], @@ -306,6 +341,7 @@ def init_vector_store_table( primary_key: Optional[str] = None, vector_size: Optional[int] = None, secondary_indexes: Optional[List[SecondaryIndex]] = None, + kind: AlgoKind = None, ) -> bool: """ Initialize the vector store new table in Google Cloud Spanner. @@ -344,6 +380,7 @@ def init_vector_store_table( metadata_columns, primary_key, secondary_indexes, + kind=kind, ) operation = database.update_ddl(ddl) @@ -363,6 +400,7 @@ def _generate_sql( column_configs, primary_key, secondary_indexes: Optional[List[SecondaryIndex]] = None, + kind: Optional[AlgoKind] = AlgoKind.KNN, ): """ Generate SQL for creating the vector store table. @@ -378,6 +416,40 @@ def _generate_sql( Returns: - str: The generated SQL. """ + + ddl_statements = [ + SpannerVectorStore._generate_create_table_sql( + table_name, + id_column, + content_column, + embedding_column, + column_configs, + primary_key, + dialect, + ) + ] + + if kind == AlgoKind.ANN: + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN( + table_name, dialect, secondary_indexes + ) + else: + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN( + table_name, embedding_column, dialect, secondary_indexes + ) + + return ddl_statements + + @staticmethod + def _generate_create_table_sql( + table_name, + id_column, + content_column, + embedding_column, + column_configs, + primary_key, + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + ): create_table_statement = f"CREATE TABLE {table_name} (\n" if not isinstance(id_column, TableColumn): @@ -438,30 +510,66 @@ def _generate_sql( + ")" ) + return create_table_statement + + @staticmethod + def _generate_secondary_indices_ddl_KNN( + table_name, embedding_column, dialect, secondary_indexes=None + ): + if not secondary_indexes: + return [] + secondary_index_ddl_statements = [] + for secondary_index in secondary_indexes: + statement = f"CREATE INDEX {secondary_index.index_name} ON {table_name}(" + statement = statement + ",".join(secondary_index.columns) + ") " - if secondary_indexes is not None: - for secondary_index in secondary_indexes: - statement = ( - f"CREATE INDEX {secondary_index.index_name} ON {table_name}(" - ) - statement = statement + ",".join(secondary_index.columns) + ") " + if dialect == DatabaseDialect.POSTGRESQL: + statement = statement + "INCLUDE (" + else: + statement = statement + "STORING (" + + if secondary_index.storing_columns is None: + secondary_index.storing_columns = [embedding_column.name] + elif embedding_column not in secondary_index.storing_columns: + secondary_index.storing_columns.append(embedding_column.name) + + statement = statement + ",".join(secondary_index.storing_columns) + ")" + secondary_index_ddl_statements.append(statement) + return secondary_index_ddl_statements + + @staticmethod + def _generate_secondary_indices_ddl_ANN( + table_name, dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, secondary_indexes=[] + ): + if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL: + raise Exception( + f"ANN is only supported for the GoogleSQL dialect not {dialect}" + ) + + secondary_index_ddl_statements = [] - if dialect == DatabaseDialect.POSTGRESQL: - statement = statement + "INCLUDE (" - else: - statement = statement + "STORING (" + for secondary_index in secondary_indexes: + statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({secondary_index.columns[0]})" + options_segments = [f"distance_type='{secondary_index.index_type}'"] + if secondary_index.tree_depth > 0: + tree_depth = secondary_index.tree_depth + if tree_depth not in ANN_ACCEPTABLE_TREE_DEPTHS: + raise Exception( + f"tree_depth: {tree_depth} is not in the acceptable values: {ANN_ACCEPTABLE_TREE_DEPTHS}" + ) + options_segments.append(f"tree_depth={secondary_index.tree_depth}") - if secondary_index.storing_columns is None: - secondary_index.storing_columns = [embedding_column.name] - elif embedding_column not in secondary_index.storing_columns: - secondary_index.storing_columns.append(embedding_column.name) + if secondary_index.num_branches > 0: + options_segments.append(f"num_branches={secondary_index.num_branches}") - statement = statement + ",".join(secondary_index.storing_columns) + ")" + if secondary_index.num_leaves > 0: + options_segments.append(f"num_leaves={secondary_index.num_leaves}") - secondary_index_ddl_statements.append(statement) + statement += "\n\tOPTIONS(" + ", ".join(options_segments) + ")" + secondary_index_ddl_statements.append(statement.strip()) - return [create_table_statement] + secondary_index_ddl_statements + return secondary_index_ddl_statements def __init__( self, diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 31f9e3d..350dfad 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -18,6 +18,8 @@ DistanceStrategy, GoogleSqlSemnatics, PGSqlSemnatics, + SecondaryIndex, + SpannerVectorStore, ) @@ -69,3 +71,37 @@ def test_distance_function_raises_exception_if_unknown(self): for strategy in strategies: with self.assertRaises(Exception): sem.getDistanceFunction(strategy) + + +class TestSpannerVectorStore_KNN(unittest.TestCase): + def test_generate_create_table_sql(self): + got = SpannerVectorStore._generate_create_table_sql( + "users", + "id", + "essays", + "science_scores", + [], + "id", + ) + want = "CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX),\n science_scores ARRAY\n) PRIMARY KEY(id)" + assert got == want + + def test_generate_secondary_indices_ddl_ANN(self): + got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=DistanceStrategy.COSINE, + num_leaves=100000, + ) + ], + ) + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n\tON Documents(DocEmbedding)\n\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert got == want From 8c44f445da314c16e802b26350a1f267c38e7aa8 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 3 Jan 2025 23:16:11 -0800 Subject: [PATCH 04/13] Update tests --- tests/unit/test_vectore_store.py | 94 +++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 350dfad..190139c 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License +from collections import namedtuple import unittest +from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from langchain_google_spanner.vector_store import ( DistanceStrategy, GoogleSqlSemnatics, @@ -100,8 +102,96 @@ def test_generate_secondary_indices_ddl_ANN(self): ) ], ) + want = [ - "CREATE VECTOR INDEX DocEmbeddingIndex\n\tON Documents(DocEmbedding)\n\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" ] - assert got == want + assert canonicalize(got) == canonicalize(want) + + def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_dialect( + self, + ): + got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=DistanceStrategy.COSINE, + num_leaves=100000, + ) + ], + ) + + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert canonicalize(got) == canonicalize(want) + + def test_generate_secondary_indices_ddl_KNN_GoogleDialect(self): + got = SpannerVectorStore._generate_secondary_indices_ddl_KNN( + "Documents", + embedding_column="custom_embedding_id1", + dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=DistanceStrategy.COSINE, + num_leaves=100000, + ) + ], + ) + + want = [ + "CREATE INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert canonicalize(got) == canonicalize(want) + + def test_generate_secondary_indices_ddl_KNN_PostgresDialect(self): + embed_column = namedtuple("Column", ["name"]) + embed_column.name = "text" + got = SpannerVectorStore._generate_secondary_indices_ddl_KNN( + "Documents", + embedding_column=embed_column, + dialect=DatabaseDialect.POSTGRESQL, + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=DistanceStrategy.COSINE, + num_leaves=100000, + ) + ], + ) + + want = [ + "CREATE INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert canonicalize(got) == canonicalize(want) + + +def trimSpaces(x: str) -> str: + return x.lstrip("\n").rstrip("\n").replace("\t", " ").strip() + + +def canonicalize(s): + return list(map(trimSpaces, s)) From 10c6a3ca83ca32e5f950b8fa3a64211e26b6dc81 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 4 Jan 2025 00:59:53 -0800 Subject: [PATCH 05/13] Implement query_ANN --- src/langchain_google_spanner/vector_store.py | 58 ++++++++- tests/unit/test_vectore_store.py | 128 +++++++++++-------- 2 files changed, 129 insertions(+), 57 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 080f1a6..390cf5e 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -109,9 +109,6 @@ class DistanceStrategy(Enum): COSINE = 1 EUCLIDEIAN = 2 DOT_PRODUCT = 3 - APPROX_DOT_PRODUCT = 4 - APPROX_COSINE = 5 - APPROX_EUCLIDEAN = 6 def __str__(self): return DISTANCE_STRATEGY_STRING[self] @@ -158,14 +155,17 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: _GOOGLE_DISTANCE_ALGO_NAMES = { - DistanceStrategy.APPROX_COSINE: "APPROX_COSINE_DISTANCE", - DistanceStrategy.APPROX_DOT_PRODUCT: "APPROX_DOT_PRODUCT", - DistanceStrategy.APPROX_EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE", DistanceStrategy.COSINE: "COSINE_DISTANCE", DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE", } +distance_strategy_to_ANN_function = { + DistanceStrategy.COSINE: "APPROX_COSINE_DISTANCE", + DistanceStrategy.DOT_PRODUCT: "APPROX_DOT_PRODUCT", + DistanceStrategy.EUCLIDEIAN: "APPROX_EUCLIDEAN_DISTANCE", +} + _GOOGLE_ALGO_INDEX_NAME = { DistanceStrategy.COSINE: "COSINE", DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", @@ -401,6 +401,7 @@ def _generate_sql( primary_key, secondary_indexes: Optional[List[SecondaryIndex]] = None, kind: Optional[AlgoKind] = AlgoKind.KNN, + limit=None, ): """ Generate SQL for creating the vector store table. @@ -970,6 +971,51 @@ def similarity_search_with_score_by_vector( ) return documents + @staticmethod + def _query_ANN( + column_name: str, + table_name: str, + index_name: str, + embedding: List[float], + embedding_column_name: str, + num_leaves: int, + strategy: DistanceStrategy = DistanceStrategy.COSINE, + limit: int = None, + is_embedding_nullable: bool = False, + where_condition: str = None, + ): + """ + Sample query: + SELECT DocId + FROM Documents@{FORCE_INDEX=DocEmbeddingIndex} + ORDER BY APPROX_EUCLIDEAN_DISTANCE( + ARRAY[1.0, 2.0, 3.0], DocEmbedding, + options => JSON '{"num_leaves_to_search": 10}') + LIMIT 100 + """ + + ann_strategy_name = distance_strategy_to_ANN_function.get(strategy, None) + if not ann_strategy_name: + raise Exception(f"{strategy} is not supported for ANN") + + sql = ( + f"SELECT {column_name} FROM {table_name}" + + "@{FORCE_INDEX=" + + f"{index_name}" + + "}\n" + + f" ORDER BY {ann_strategy_name}(\n" + + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '" + + "{\"num_leaves_to_search\": %s})\n"%(num_leaves) + ) + + if where_condition: + sql += " WHERE " + where_condition + "\n" + + if limit: + sql += f"LIMIT {limit}" + + return sql + def _get_rows_by_similarity_search( self, embedding: List[float], diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 190139c..9ac5ef2 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -17,6 +17,7 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from langchain_google_spanner.vector_store import ( + AlgoKind, DistanceStrategy, GoogleSqlSemnatics, PGSqlSemnatics, @@ -31,9 +32,6 @@ def test_distance_function_to_string(self): (DistanceStrategy.COSINE, "COSINE_DISTANCE"), (DistanceStrategy.DOT_PRODUCT, "DOT_PRODUCT"), (DistanceStrategy.EUCLIDEIAN, "EUCLIDEAN_DISTANCE"), - (DistanceStrategy.APPROX_COSINE, "APPROX_COSINE_DISTANCE"), - (DistanceStrategy.APPROX_DOT_PRODUCT, "APPROX_DOT_PRODUCT"), - (DistanceStrategy.APPROX_EUCLIDEAN, "APPROX_EUCLIDEAN_DISTANCE"), ] sem = GoogleSqlSemnatics() @@ -65,9 +63,8 @@ def test_distance_function_to_string(self): def test_distance_function_raises_exception_if_unknown(self): strategies = [ - DistanceStrategy.APPROX_COSINE, - DistanceStrategy.APPROX_DOT_PRODUCT, - DistanceStrategy.APPROX_EUCLIDEAN, + 100, + -1, ] for strategy in strategies: @@ -89,57 +86,67 @@ def test_generate_create_table_sql(self): assert got == want def test_generate_secondary_indices_ddl_ANN(self): - got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( - "Documents", - secondary_indexes=[ - SecondaryIndex( - index_name="DocEmbeddingIndex", - columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=DistanceStrategy.COSINE, - num_leaves=100000, - ) - ], - ) - - want = [ - "CREATE VECTOR INDEX DocEmbeddingIndex\n" - + " ON Documents(DocEmbedding)\n" - + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + strategies = [ + DistanceStrategy.COSINE, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.EUCLIDEIAN, ] - assert canonicalize(got) == canonicalize(want) + for distance_strategy in strategies: + got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ) + ], + ) + + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert canonicalize(got) == canonicalize(want) def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_dialect( self, ): - got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( - "Documents", - secondary_indexes=[ - SecondaryIndex( - index_name="DocEmbeddingIndex", - columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=DistanceStrategy.COSINE, - num_leaves=100000, - ) - ], - ) - - want = [ - "CREATE VECTOR INDEX DocEmbeddingIndex\n" - + " ON Documents(DocEmbedding)\n" - + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + strategies = [ + DistanceStrategy.COSINE, + DistanceStrategy.DOT_PRODUCT, + DistanceStrategy.EUCLIDEIAN, ] - assert canonicalize(got) == canonicalize(want) + for strategy in strategies: + with self.assertRaises(Exception): + SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + dialect=DatabaseDialect.POSTGRESQL, + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + num_branches=1000, + tree_depth=3, + index_type=strategy, + num_leaves=100000, + ) + ], + ) def test_generate_secondary_indices_ddl_KNN_GoogleDialect(self): + embed_column = namedtuple("Column", ["name"]) + embed_column.name = "text" got = SpannerVectorStore._generate_secondary_indices_ddl_KNN( "Documents", - embedding_column="custom_embedding_id1", + embedding_column=embed_column, dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, secondary_indexes=[ SecondaryIndex( @@ -154,9 +161,7 @@ def test_generate_secondary_indices_ddl_KNN_GoogleDialect(self): ) want = [ - "CREATE INDEX DocEmbeddingIndex\n" - + " ON Documents(DocEmbedding)\n" - + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + "CREATE INDEX DocEmbeddingIndex ON Documents(DocEmbedding) STORING (text)" ] assert canonicalize(got) == canonicalize(want) @@ -181,13 +186,34 @@ def test_generate_secondary_indices_ddl_KNN_PostgresDialect(self): ) want = [ - "CREATE INDEX DocEmbeddingIndex\n" - + " ON Documents(DocEmbedding)\n" - + " OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=1000, num_leaves=100000)" + "CREATE INDEX DocEmbeddingIndex ON Documents(DocEmbedding) INCLUDE (text)" ] assert canonicalize(got) == canonicalize(want) + def test_query_ANN(self): + got = SpannerVectorStore._query_ANN( + "DocId", + "Documents", + "DocEmbeddingIndex", + [1.0, 2.0, 3.0], + "DocEmbedding", + 10, + DistanceStrategy.COSINE, + limit=100, + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + " ORDER BY APPROX_COSINE_DISTANCE(\n" + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON '{\"num_leaves_to_search\": 10})\n" + + "LIMIT 100" + ) + + print("got", got) + print("want", want) + assert got == want + def trimSpaces(x: str) -> str: return x.lstrip("\n").rstrip("\n").replace("\t", " ").strip() From fc45986f2759040e1245d391de795c21c717d470 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 4 Jan 2025 02:02:48 -0800 Subject: [PATCH 06/13] Format --- src/langchain_google_spanner/vector_store.py | 42 ++++++++++++++++++-- tests/unit/test_vectore_store.py | 3 +- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 390cf5e..1b46b90 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -963,7 +963,7 @@ def similarity_search_with_score_by_vector( List[Document]: List of documents most similar to the query. """ - results, column_order_map = self._get_rows_by_similarity_search( + results, column_order_map = self._get_rows_by_similarity_search_knn( embedding, k, pre_filter ) documents = self._get_documents_from_query_results( @@ -971,6 +971,40 @@ def similarity_search_with_score_by_vector( ) return documents + def search_by_ANN( + self, + column_name: str, + table_name: str, + index_name: str, + embedding: List[float], + embedding_column_name: str, + num_leaves: int, + strategy: DistanceStrategy = DistanceStrategy.COSINE, + limit: int = None, + is_embedding_nullable: bool = False, + where_condition: str = None, + ) -> List[Any]: + sql = SpannerVectorStore._query_ANN( + column_name, + table_name, + index_name, + embedding, + embedding_column_name, + num_leaves, + strategy, + limit, + is_embedding_nullable, + where_condition, + ) + staleness = self._query_parameters.staleness + with self._database.snapshot( + **staleness if staleness is not None else {} + ) as snapshot: + results = snapshot.execute_sql( + sql=sql_query, + ) + return list(results) + @staticmethod def _query_ANN( column_name: str, @@ -1005,7 +1039,7 @@ def _query_ANN( + "}\n" + f" ORDER BY {ann_strategy_name}(\n" + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '" - + "{\"num_leaves_to_search\": %s})\n"%(num_leaves) + + '{"num_leaves_to_search": %s})\n' % (num_leaves) ) if where_condition: @@ -1016,7 +1050,7 @@ def _query_ANN( return sql - def _get_rows_by_similarity_search( + def _get_rows_by_similarity_search_knn( self, embedding: List[float], k: int, @@ -1193,7 +1227,7 @@ def max_marginal_relevance_search_with_score_by_vector( List of Documents and similarity scores selected by maximal marginal relevance and score for each. """ - results, column_order_map = self._get_rows_by_similarity_search( + results, column_order_map = self._get_rows_by_similarity_search_knn( embedding, fetch_k, pre_filter ) diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 9ac5ef2..b4bb433 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -16,6 +16,7 @@ import unittest from google.cloud.spanner_admin_database_v1.types import DatabaseDialect + from langchain_google_spanner.vector_store import ( AlgoKind, DistanceStrategy, @@ -206,7 +207,7 @@ def test_query_ANN(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + " ORDER BY APPROX_COSINE_DISTANCE(\n" - + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON '{\"num_leaves_to_search\": 10})\n" + + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + "LIMIT 100" ) From 64358abf28352e3831a0eb5ad1083477eafceae5 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sun, 19 Jan 2025 23:44:54 -0800 Subject: [PATCH 07/13] ANN: account for nullable columns in search and index creation --- README.rst | 3 +- src/langchain_google_spanner/vector_store.py | 27 ++++++- tests/unit/test_vectore_store.py | 79 ++++++++++++++------ 3 files changed, 81 insertions(+), 28 deletions(-) diff --git a/README.rst b/README.rst index f09c809..a379847 100644 --- a/README.rst +++ b/README.rst @@ -210,4 +210,5 @@ This is not an officially supported Google product. Limitations ---------- -* Approximate Nearest Neighbors (ANN) strategies are only support for the GoogleSQL dialect +* Approximate Nearest Neighbors (ANN) strategies are only supported for the GoogleSQL dialect +* ANN's `ALTER VECTOR INDEX` is not supported by [Google Cloud Spanner](https://cloud.google.com/spanner/docs/find-approximate-nearest-neighbors#limitations) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 1b46b90..de5df68 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -88,6 +88,7 @@ class SecondaryIndex: columns: list[str] storing_columns: Optional[list[str]] = None num_leaves: Optional[int] = None # Only necessary for ANN + nullable_column: Optional[bool] = False # Only necessary for ANN num_branches: Optional[int] = None # Only necessary for ANN tree_depth: Optional[int] = None # Only necessary for ANN index_type: Optional[DistanceStrategy] = None # Only necessary for ANN @@ -551,7 +552,10 @@ def _generate_secondary_indices_ddl_ANN( secondary_index_ddl_statements = [] for secondary_index in secondary_indexes: - statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({secondary_index.columns[0]})" + column_name = secondary_index.columns[0] + statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})" + if secondary_index.nullable_column: + statement += f"\n\tWHERE {column_name} IS NOT NULL" options_segments = [f"distance_type='{secondary_index.index_type}'"] if secondary_index.tree_depth > 0: tree_depth = secondary_index.tree_depth @@ -983,6 +987,7 @@ def search_by_ANN( limit: int = None, is_embedding_nullable: bool = False, where_condition: str = None, + column_is_nullable: bool = False, ) -> List[Any]: sql = SpannerVectorStore._query_ANN( column_name, @@ -995,6 +1000,7 @@ def search_by_ANN( limit, is_embedding_nullable, where_condition, + column_is_nullable=column_is_nullable, ) staleness = self._query_parameters.staleness with self._database.snapshot( @@ -1017,6 +1023,7 @@ def _query_ANN( limit: int = None, is_embedding_nullable: bool = False, where_condition: str = None, + column_is_nullable: bool = False, ): """ Sample query: @@ -1026,6 +1033,16 @@ def _query_ANN( ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON '{"num_leaves_to_search": 10}') LIMIT 100 + + OR + + SELECT DocId + FROM Documents@{FORCE_INDEX=DocEmbeddingIndex} + WHERE NullableDocEmbedding IS NOT NULL + ORDER BY APPROX_EUCLIDEAN_DISTANCE( + ARRAY[1.0, 2.0, 3.0], NullableDocEmbedding, + options => JSON '{"num_leaves_to_search": 10}') + LIMIT 100 """ ann_strategy_name = distance_strategy_to_ANN_function.get(strategy, None) @@ -1036,8 +1053,12 @@ def _query_ANN( f"SELECT {column_name} FROM {table_name}" + "@{FORCE_INDEX=" + f"{index_name}" - + "}\n" - + f" ORDER BY {ann_strategy_name}(\n" + + ( + "}\n" + if (not column_is_nullable) + else "}\nWHERE " + f"{embedding_column_name} IS NOT NULL\n" + ) + + f"ORDER BY {ann_strategy_name}(\n" + f" ARRAY{embedding}, {embedding_column_name}, options => JSON '" + '{"num_leaves_to_search": %s})\n' % (num_leaves) ) diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index b4bb433..c86bef0 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -93,28 +93,38 @@ def test_generate_secondary_indices_ddl_ANN(self): DistanceStrategy.EUCLIDEIAN, ] + nullables = [True, False] for distance_strategy in strategies: - got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( - "Documents", - secondary_indexes=[ - SecondaryIndex( - index_name="DocEmbeddingIndex", - columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=distance_strategy, - num_leaves=100000, - ) - ], - ) - - want = [ - "CREATE VECTOR INDEX DocEmbeddingIndex\n" - + " ON Documents(DocEmbedding)\n" - + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" - ] - - assert canonicalize(got) == canonicalize(want) + for nullable in nullables: + got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( + "Documents", + secondary_indexes=[ + SecondaryIndex( + index_name="DocEmbeddingIndex", + columns=["DocEmbedding"], + nullable_column=nullable, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ) + ], + ) + + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + " WHERE DocEmbedding IS NOT NULL\n" + + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + if not nullable: + want = [ + "CREATE VECTOR INDEX DocEmbeddingIndex\n" + + " ON Documents(DocEmbedding)\n" + + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + ] + + assert canonicalize(got) == canonicalize(want) def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_dialect( self, @@ -206,13 +216,34 @@ def test_query_ANN(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" - + " ORDER BY APPROX_COSINE_DISTANCE(\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + + "LIMIT 100" + ) + + assert got == want + + def test_query_ANN_column_is_nullable(self): + got = SpannerVectorStore._query_ANN( + "DocId", + "Documents", + "DocEmbeddingIndex", + [1.0, 2.0, 3.0], + "DocEmbedding", + 10, + DistanceStrategy.COSINE, + limit=100, + column_is_nullable=True, + ) + + want = ( + "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + + "WHERE DocEmbedding IS NOT NULL\n" + + "ORDER BY APPROX_COSINE_DISTANCE(\n" + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + "LIMIT 100" ) - print("got", got) - print("want", want) assert got == want From 0a1982aca90fad04f9ce3b8b2165c21d0d4529ae Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 20 Jan 2025 02:33:19 -0800 Subject: [PATCH 08/13] Add integration tests and some TODOs --- .../integration/test_spanner_vector_store.py | 188 +++++++++++++++++- 1 file changed, 187 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index d3f6cfc..7a16fc4 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -26,6 +26,7 @@ QueryParameters, SpannerVectorStore, TableColumn, + AlgoKind, ) project_id = os.environ["PROJECT_ID"] @@ -33,6 +34,7 @@ google_database = os.environ["GOOGLE_DATABASE"] pg_database = os.environ["PG_DATABASE"] table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") +table_name_ANN = table_name + "_ANN" OPERATION_TIMEOUT_SECONDS = 240 @@ -207,7 +209,7 @@ def test_init_vector_store_table4(self): ) -class TestSpannerVectorStoreGoogleSQL: +class TestSpannerVectorStoreGoogleSQL_KNN: @pytest.fixture(scope="class") def setup_database(self, client): SpannerVectorStore.init_vector_store_table( @@ -389,6 +391,190 @@ def test_spanner_vector_search_data4(self, setup_database): assert len(docs) == 3 +class TestSpannerVectorStoreGoogleSQL_ANN: + @pytest.fixture(scope="class") + def setup_database(self, client): + SpannerVectorStore.init_vector_store_table( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column="row_id", + metadata_columns=[ + TableColumn(name="metadata", type="JSON", is_null=True), + TableColumn(name="title", type="STRING(MAX)", is_null=False), + ], + kind=AlgoKind.ANN, + ) + + loader = HNLoader("https://news.ycombinator.com/item?id=34817881") + + embeddings = FakeEmbeddings(size=3) + + yield loader, embeddings + + print("\nPerforming GSQL cleanup after each ANN test...") + + database = client.instance(instance_id).database(google_database) + operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name_ANN}"]) + operation.result(OPERATION_TIMEOUT_SECONDS) + + # Code to perform teardown after each test goes here + print("\nGSQL Cleanup complete.") + + def test_add_data1(self, setup_database): + loader, embeddings = setup_database + + db = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name_ANN, + id_column="row_id", + ignore_metadata_columns=[], + embedding_service=embeddings, + metadata_json_column="metadata", + kind=AlgoKind.ANN, + ) + + docs = loader.load() + ids = [str(uuid.uuid4()) for _ in range(len(docs))] + ids_row_inserted = db.add_documents(documents=docs, ids=ids) + assert ids == ids_row_inserted + + def test_add_data2(self, setup_database): + loader, embeddings = setup_database + + db = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name, + id_column="row_id", + ignore_metadata_columns=[], + embedding_service=embeddings, + metadata_json_column="metadata", + kind=AlgoKind.ANN, + ) + + texts = [ + "Langchain Test Text 1", + "Langchain Test Text 2", + "Langchain Test Text 3", + ] + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + ids_row_inserted = db.add_texts( + texts=texts, + ids=ids, + metadatas=[ + {"title": "Title 1"}, + {"title": "Title 2"}, + {"title": "Title 3"}, + ], + ) + assert ids == ids_row_inserted + + def test_delete_data(self, setup_database): + loader, embeddings = setup_database + + db = SpannerVectorStore( + instance_id=instance_id, + database_id=google_database, + table_name=table_name, + id_column="row_id", + ignore_metadata_columns=[], + embedding_service=embeddings, + metadata_json_column="metadata", + kind=AlgoKind.ANN, + ) + + docs = loader.load() + deleted = db.delete(documents=[docs[0], docs[1]]) + assert deleted == True + + def test_search_data1(self, setup_database): + # loader, embeddings = setup_database + # db = SpannerVectorStore( + # instance_id=instance_id, + # database_id=google_database, + # table_name=table_name, + # id_column="row_id", + # ignore_metadata_columns=[], + # embedding_service=embeddings, + # metadata_json_column="metadata", + # kind=AlgoKind.ANN, + # ) + # docs = db.similarity_search( + # "Testing the langchain integration with spanner", k=2 + # ) + # assert len(docs) == 2 + pass + + def test_search_data2(self, setup_database): + # TODO: Implement me + # loader, embeddings = setup_database + # db = SpannerVectorStore( + # instance_id=instance_id, + # database_id=google_database, + # table_name=table_name, + # id_column="row_id", + # ignore_metadata_columns=[], + # embedding_service=embeddings, + # metadata_json_column="metadata", + # kind=AlgoKind.ANN, + # ) + # embeds = embeddings.embed_query( + # "Testing the langchain integration with spanner" + # ) + # docs = db.similarity_search_by_vector(embeds, k=3, pre_filter="1 = 1") + # assert len(docs) == 3 + pass + + def test_search_data3(self, setup_database): + # TODO: Implement me + # loader, embeddings = setup_database + # db = SpannerVectorStore( + # instance_id=instance_id, + # database_id=google_database, + # table_name=table_name, + # id_column="row_id", + # ignore_metadata_columns=[], + # embedding_service=embeddings, + # metadata_json_column="metadata", + # query_parameters=QueryParameters( + # distance_strategy=DistanceStrategy.COSINE, + # max_staleness=datetime.timedelta(seconds=15), + # ), + # kind=AlgoKind.ANN, + # ) + # + # docs = db.similarity_search( + # "Testing the langchain integration with spanner", k=3 + # ) + # + # assert len(docs) == 3 + pass + + def test_search_data4(self, setup_database): + # loader, embeddings = setup_database + # db = SpannerVectorStore( + # instance_id=instance_id, + # database_id=google_database, + # table_name=table_name, + # id_column="row_id", + # ignore_metadata_columns=[], + # embedding_service=embeddings, + # metadata_json_column="metadata", + # query_parameters=QueryParameters( + # distance_strategy=DistanceStrategy.COSINE, + # max_staleness=datetime.timedelta(seconds=15), + # ), + # kind=AlgoKind.ANN, + # ) + # docs = db.max_marginal_relevance_search( + # "Testing the langchain integration with spanner", k=3 + # ) + # assert len(docs) == 3 + pass + + class TestSpannerVectorStorePGSQL: @pytest.fixture(scope="class") def setup_database(self, client): From 44d0996f0e44c9f5702555621603b1f362237c8b Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Tue, 21 Jan 2025 01:47:14 -0800 Subject: [PATCH 09/13] Apply some changes from code review --- src/langchain_google_spanner/vector_store.py | 99 ++++++++++---------- tests/unit/test_vectore_store.py | 54 ++++++----- 2 files changed, 80 insertions(+), 73 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index de5df68..41e21db 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -87,11 +87,6 @@ class SecondaryIndex: index_name: str columns: list[str] storing_columns: Optional[list[str]] = None - num_leaves: Optional[int] = None # Only necessary for ANN - nullable_column: Optional[bool] = False # Only necessary for ANN - num_branches: Optional[int] = None # Only necessary for ANN - tree_depth: Optional[int] = None # Only necessary for ANN - index_type: Optional[DistanceStrategy] = None # Only necessary for ANN def __post_init__(self): # Check if column_name is None after initialization @@ -102,24 +97,42 @@ def __post_init__(self): raise ValueError("Index Columns can't be None") +@dataclass +class VectorSearchIndex: + """ + The index for use with Approximate Nearest Neighbor (ANN) vector search. + """ + index_name: str + columns: list[str] + num_leaves: int + num_branches: int + tree_depth: int + index_type: DistanceStrategy + nullable_column: bool = False + + def __post_init__(self): + if self.index_name is None: + raise ValueError("index_name must be set") + + if len(self.columns) == 0: + raise ValueError("columns must be set") + + ok_tree_depth = self.tree_depth in (2, 3) + if not ok_tree_depth: + raise ValueError("tree_depth must be either 2 or 3") + + class DistanceStrategy(Enum): """ Enum for distance calculation strategies. """ COSINE = 1 - EUCLIDEIAN = 2 + EUCLIDEAN = 2 DOT_PRODUCT = 3 def __str__(self): - return DISTANCE_STRATEGY_STRING[self] - - -DISTANCE_STRATEGY_STRING = { - DistanceStrategy.COSINE: "COSINE", - DistanceStrategy.EUCLIDEIAN: "EUCLIDEIAN", - DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", -} + return self.name class DialectSemantics(ABC): @@ -128,7 +141,7 @@ class DialectSemantics(ABC): """ @abstractmethod - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: """ Abstract method to get the distance function based on the provided distance strategy. @@ -155,22 +168,18 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: ) -_GOOGLE_DISTANCE_ALGO_NAMES = { +# Maps between distance strategy enums and the appropriate vector search index name. +GOOGLE_DIALECT_DISTANCE_FUCNTIONS = { DistanceStrategy.COSINE: "COSINE_DISTANCE", DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", - DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE", + DistanceStrategy.EUCLIDEAN: "EUCLIDEAN_DISTANCE", } +# Maps between distance strategy and the appropriate ANN search function name. distance_strategy_to_ANN_function = { DistanceStrategy.COSINE: "APPROX_COSINE_DISTANCE", DistanceStrategy.DOT_PRODUCT: "APPROX_DOT_PRODUCT", - DistanceStrategy.EUCLIDEIAN: "APPROX_EUCLIDEAN_DISTANCE", -} - -_GOOGLE_ALGO_INDEX_NAME = { - DistanceStrategy.COSINE: "COSINE", - DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", - DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN", + DistanceStrategy.EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE", } @@ -179,8 +188,8 @@ class GoogleSqlSemnatics(DialectSemantics): Implementation of dialect semantics for Google SQL. """ - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - return _GOOGLE_DISTANCE_ALGO_NAMES.get(distance_strategy, "EUCLIDEAN") + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: + return GOOGLE_DIALECT_DISTANCE_FUCNTIONS.get(distance_strategy, "EUCLIDEAN") def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: where_clause_condition = " AND ".join( @@ -201,10 +210,11 @@ def getIndexDistanceType(self, distance_strategy) -> str: return value -_PG_DISTANCE_ALGO_NAMES = { +# Maps between DistanceStrategy and the expected PostgreSQL distance equivalent. +PG_DIALECT_DISTANCE_FUNCTIONS = { DistanceStrategy.COSINE: "spanner.cosine_distance", DistanceStrategy.DOT_PRODUCT: "spanner.dot_product", - DistanceStrategy.EUCLIDEIAN: "spanner.euclidean_distance", + DistanceStrategy.EUCLIDEAN: "spanner.euclidean_distance", } @@ -213,8 +223,8 @@ class PGSqlSemnatics(DialectSemantics): Implementation of dialect semantics for PostgreSQL. """ - def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str: - name = _PG_DISTANCE_ALGO_NAMES.get(distance_strategy, None) + def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: + name = PG_DIALECT_DISTANCE_FUNCTIONS.get(distance_strategy, None) if name is None: raise Exception( "Unsupported PostgreSQL distance strategy: {}".format(distance_strategy) @@ -254,7 +264,7 @@ class QueryParameters: class NearestNeighborsAlgorithm(Enum): """ - Enum for nearest neighbors search algorithms. + Enum for k-nearest neighbors search algorithms. """ EXACT_NEAREST_NEIGHBOR = 1 @@ -263,7 +273,7 @@ class NearestNeighborsAlgorithm(Enum): def __init__( self, algorithm=NearestNeighborsAlgorithm.EXACT_NEAREST_NEIGHBOR, - distance_strategy=DistanceStrategy.EUCLIDEIAN, + distance_strategy=DistanceStrategy.EUCLIDEAN, read_timestamp: Optional[datetime.datetime] = None, min_read_timestamp: Optional[datetime.datetime] = None, max_staleness: Optional[datetime.timedelta] = None, @@ -303,10 +313,6 @@ def __init__( self.staleness = {key: value} -DEFAULT_ANN_TREE_DEPTH = 2 -ANN_ACCEPTABLE_TREE_DEPTHS = (2, 3) - - class AlgoKind(Enum): KNN = 0 ANN = 1 @@ -341,8 +347,8 @@ def init_vector_store_table( metadata_columns: Optional[List[TableColumn]] = None, primary_key: Optional[str] = None, vector_size: Optional[int] = None, - secondary_indexes: Optional[List[SecondaryIndex]] = None, - kind: AlgoKind = None, + secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, + kind: AlgoKind = AlgoKind.KNN, ) -> bool: """ Initialize the vector store new table in Google Cloud Spanner. @@ -357,6 +363,7 @@ def init_vector_store_table( - embedding_column (str): The name of the embedding column. Defaults to EMBEDDING_COLUMN_NAME. - metadata_columns (Optional[List[Tuple]]): List of tuples containing metadata column information. Defaults to None. - vector_size (Optional[int]): The size of the vector. Defaults to None. + - kind (AlgoKind): Defines whether to use k-Nearest Neighbors or Approximate Nearest Neighbors. Defaults to kNN. """ client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE) @@ -400,7 +407,7 @@ def _generate_sql( embedding_column, column_configs, primary_key, - secondary_indexes: Optional[List[SecondaryIndex]] = None, + secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, kind: Optional[AlgoKind] = AlgoKind.KNN, limit=None, ): @@ -546,7 +553,7 @@ def _generate_secondary_indices_ddl_ANN( ): if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL: raise Exception( - f"ANN is only supported for the GoogleSQL dialect not {dialect}" + f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?" ) secondary_index_ddl_statements = [] @@ -554,15 +561,13 @@ def _generate_secondary_indices_ddl_ANN( for secondary_index in secondary_indexes: column_name = secondary_index.columns[0] statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})" - if secondary_index.nullable_column: + if getattr(secondary_index, "nullable_column", False): statement += f"\n\tWHERE {column_name} IS NOT NULL" options_segments = [f"distance_type='{secondary_index.index_type}'"] - if secondary_index.tree_depth > 0: + if getattr(secondary_index, "tree_depth", 0) > 0: tree_depth = secondary_index.tree_depth - if tree_depth not in ANN_ACCEPTABLE_TREE_DEPTHS: - raise Exception( - f"tree_depth: {tree_depth} is not in the acceptable values: {ANN_ACCEPTABLE_TREE_DEPTHS}" - ) + if tree_depth not in (2, 3): + raise Exception(f"tree_depth: {tree_depth} must be either 2 or 3") options_segments.append(f"tree_depth={secondary_index.tree_depth}") if secondary_index.num_branches > 0: @@ -761,7 +766,7 @@ def _validate_table_schema(self, column_type_map, types, default_columns): def _select_relevance_score_fn(self) -> Callable[[float], float]: if self._query_parameters.distance_strategy == DistanceStrategy.COSINE: return self._cosine_relevance_score_fn - elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEIAN: + elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEAN: return self._euclidean_relevance_score_fn else: raise Exception( diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index c86bef0..31a0b03 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -18,12 +18,12 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from langchain_google_spanner.vector_store import ( - AlgoKind, DistanceStrategy, GoogleSqlSemnatics, PGSqlSemnatics, SecondaryIndex, SpannerVectorStore, + VectorSearchIndex, ) @@ -32,7 +32,7 @@ def test_distance_function_to_string(self): cases = [ (DistanceStrategy.COSINE, "COSINE_DISTANCE"), (DistanceStrategy.DOT_PRODUCT, "DOT_PRODUCT"), - (DistanceStrategy.EUCLIDEIAN, "EUCLIDEAN_DISTANCE"), + (DistanceStrategy.EUCLIDEAN, "EUCLIDEAN_DISTANCE"), ] sem = GoogleSqlSemnatics() @@ -46,18 +46,19 @@ def test_distance_function_to_string(self): class TestPGSqlSemnatics(unittest.TestCase): + sem = PGSqlSemnatics() + def test_distance_function_to_string(self): cases = [ (DistanceStrategy.COSINE, "spanner.cosine_distance"), (DistanceStrategy.DOT_PRODUCT, "spanner.dot_product"), - (DistanceStrategy.EUCLIDEIAN, "spanner.euclidean_distance"), + (DistanceStrategy.EUCLIDEAN, "spanner.euclidean_distance"), ] - sem = PGSqlSemnatics() got_results = [] want_results = [] for strategy, want_str in cases: - got_results.append(sem.getDistanceFunction(strategy)) + got_results.append(self.sem.getDistanceFunction(strategy)) want_results.append(want_str) assert got_results == want_results @@ -70,7 +71,7 @@ def test_distance_function_raises_exception_if_unknown(self): for strategy in strategies: with self.assertRaises(Exception): - sem.getDistanceFunction(strategy) + self.sem.getDistanceFunction(strategy) class TestSpannerVectorStore_KNN(unittest.TestCase): @@ -83,14 +84,17 @@ def test_generate_create_table_sql(self): [], "id", ) - want = "CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX),\n science_scores ARRAY\n) PRIMARY KEY(id)" + want = ( + "CREATE TABLE users (\n id STRING(36),\n essays STRING(MAX)," + + "\n science_scores ARRAY\n) PRIMARY KEY(id)" + ) assert got == want def test_generate_secondary_indices_ddl_ANN(self): strategies = [ DistanceStrategy.COSINE, DistanceStrategy.DOT_PRODUCT, - DistanceStrategy.EUCLIDEIAN, + DistanceStrategy.EUCLIDEAN, ] nullables = [True, False] @@ -99,7 +103,7 @@ def test_generate_secondary_indices_ddl_ANN(self): got = SpannerVectorStore._generate_secondary_indices_ddl_ANN( "Documents", secondary_indexes=[ - SecondaryIndex( + VectorSearchIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], nullable_column=nullable, @@ -115,24 +119,26 @@ def test_generate_secondary_indices_ddl_ANN(self): "CREATE VECTOR INDEX DocEmbeddingIndex\n" + " ON Documents(DocEmbedding)\n" + " WHERE DocEmbedding IS NOT NULL\n" - + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + + f" OPTIONS(distance_type='{distance_strategy}', " + + "tree_depth=3, num_branches=1000, num_leaves=100000)" ] if not nullable: want = [ "CREATE VECTOR INDEX DocEmbeddingIndex\n" + " ON Documents(DocEmbedding)\n" - + f" OPTIONS(distance_type='{distance_strategy}', tree_depth=3, num_branches=1000, num_leaves=100000)" + + f" OPTIONS(distance_type='{distance_strategy}', " + + "tree_depth=3, num_branches=1000, num_leaves=100000)" ] assert canonicalize(got) == canonicalize(want) - def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_dialect( + def test_generate_ANN_indices_exception_for_non_GoogleSQL_dialect( self, ): strategies = [ DistanceStrategy.COSINE, DistanceStrategy.DOT_PRODUCT, - DistanceStrategy.EUCLIDEIAN, + DistanceStrategy.EUCLIDEAN, ] for strategy in strategies: @@ -141,7 +147,7 @@ def test_generate_secondary_indices_ddl_ANN_raises_exception_for_non_GoogleSQL_d "Documents", dialect=DatabaseDialect.POSTGRESQL, secondary_indexes=[ - SecondaryIndex( + VectorSearchIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], num_branches=1000, @@ -163,16 +169,13 @@ def test_generate_secondary_indices_ddl_KNN_GoogleDialect(self): SecondaryIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=DistanceStrategy.COSINE, - num_leaves=100000, ) ], ) want = [ - "CREATE INDEX DocEmbeddingIndex ON Documents(DocEmbedding) STORING (text)" + "CREATE INDEX DocEmbeddingIndex ON " + + "Documents(DocEmbedding) STORING (text)" ] assert canonicalize(got) == canonicalize(want) @@ -188,16 +191,13 @@ def test_generate_secondary_indices_ddl_KNN_PostgresDialect(self): SecondaryIndex( index_name="DocEmbeddingIndex", columns=["DocEmbedding"], - num_branches=1000, - tree_depth=3, - index_type=DistanceStrategy.COSINE, - num_leaves=100000, ) ], ) want = [ - "CREATE INDEX DocEmbeddingIndex ON Documents(DocEmbedding) INCLUDE (text)" + "CREATE INDEX DocEmbeddingIndex ON " + + "Documents(DocEmbedding) INCLUDE (text)" ] assert canonicalize(got) == canonicalize(want) @@ -217,7 +217,8 @@ def test_query_ANN(self): want = ( "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" - + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})\n' + "LIMIT 100" ) @@ -240,7 +241,8 @@ def test_query_ANN_column_is_nullable(self): "SELECT DocId FROM Documents@{FORCE_INDEX=DocEmbeddingIndex}\n" + "WHERE DocEmbedding IS NOT NULL\n" + "ORDER BY APPROX_COSINE_DISTANCE(\n" - + ' ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON \'{"num_leaves_to_search": 10})\n' + + " ARRAY[1.0, 2.0, 3.0], DocEmbedding, options => JSON " + + '\'{"num_leaves_to_search": 10})\n' + "LIMIT 100" ) From ab5cc5050938a6a3b48345242e417d101ad1d04e Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 25 Jan 2025 14:50:20 +0200 Subject: [PATCH 10/13] Make distance strategy dicts have consistent names --- src/langchain_google_spanner/vector_store.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index 41e21db..f99205c 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -102,6 +102,7 @@ class VectorSearchIndex: """ The index for use with Approximate Nearest Neighbor (ANN) vector search. """ + index_name: str columns: list[str] num_leaves: int @@ -169,14 +170,14 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]: # Maps between distance strategy enums and the appropriate vector search index name. -GOOGLE_DIALECT_DISTANCE_FUCNTIONS = { +GOOGLE_DIALECT_TO_KNN_DISTANCE_FUNCTIONS = { DistanceStrategy.COSINE: "COSINE_DISTANCE", DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT", DistanceStrategy.EUCLIDEAN: "EUCLIDEAN_DISTANCE", } # Maps between distance strategy and the appropriate ANN search function name. -distance_strategy_to_ANN_function = { +GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS = { DistanceStrategy.COSINE: "APPROX_COSINE_DISTANCE", DistanceStrategy.DOT_PRODUCT: "APPROX_DOT_PRODUCT", DistanceStrategy.EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE", @@ -189,7 +190,9 @@ class GoogleSqlSemnatics(DialectSemantics): """ def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: - return GOOGLE_DIALECT_DISTANCE_FUCNTIONS.get(distance_strategy, "EUCLIDEAN") + return GOOGLE_DIALECT_TO_KNN_DISTANCE_FUNCTIONS.get( + distance_strategy, "EUCLIDEAN" + ) def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]: where_clause_condition = " AND ".join( @@ -211,7 +214,7 @@ def getIndexDistanceType(self, distance_strategy) -> str: # Maps between DistanceStrategy and the expected PostgreSQL distance equivalent. -PG_DIALECT_DISTANCE_FUNCTIONS = { +PG_DIALECT_TO_KNN_DISTANCE_FUNCTIONS = { DistanceStrategy.COSINE: "spanner.cosine_distance", DistanceStrategy.DOT_PRODUCT: "spanner.dot_product", DistanceStrategy.EUCLIDEAN: "spanner.euclidean_distance", @@ -224,7 +227,7 @@ class PGSqlSemnatics(DialectSemantics): """ def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str: - name = PG_DIALECT_DISTANCE_FUNCTIONS.get(distance_strategy, None) + name = PG_DIALECT_TO_KNN_DISTANCE_FUNCTIONS.get(distance_strategy, None) if name is None: raise Exception( "Unsupported PostgreSQL distance strategy: {}".format(distance_strategy) @@ -1050,7 +1053,7 @@ def _query_ANN( LIMIT 100 """ - ann_strategy_name = distance_strategy_to_ANN_function.get(strategy, None) + ann_strategy_name = GOOGLE_DIALECT_TO_ANN_DISTANCE_FUNCTIONS.get(strategy, None) if not ann_strategy_name: raise Exception(f"{strategy} is not supported for ANN") From 70f03d6c008dea7eb2d5e77e0eeb2031e443745f Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Sat, 25 Jan 2025 22:36:27 +0200 Subject: [PATCH 11/13] Link up integration tests for ANN to a live database --- noxfile.py | 36 +++ requirements.txt | 2 +- src/langchain_google_spanner/vector_store.py | 72 ++++-- .../integration/test_spanner_vector_store.py | 229 +++++++++++++++--- tests/unit/test_vectore_store.py | 2 +- 5 files changed, 295 insertions(+), 46 deletions(-) diff --git a/noxfile.py b/noxfile.py index 7eaecb3..2acdb42 100644 --- a/noxfile.py +++ b/noxfile.py @@ -180,6 +180,7 @@ def unit(session): "py.test", "--quiet", os.path.join("tests", "unit"), + *session.posargs, ) @@ -203,3 +204,38 @@ def install_unittest_dependencies(session, *constraints): "-r", "requirements.txt", ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def integration(session): + install_integrationtest_dependencies(session) + session.run( + "py.test", + "--quiet", + os.path.join("tests", "integration"), + *session.posargs, + ) + + +INTEGRATION_TEST_STANDARD_DEPENDENCIES = [ + "langchain_google_vertexai", + "pytest", +] +INTEGRATION_TEST_DEPENDENCIES: List[str] = [] + + +def install_integrationtest_dependencies(session, *constraints): + standard_deps = ( + INTEGRATION_TEST_STANDARD_DEPENDENCIES + INTEGRATION_TEST_DEPENDENCIES + ) + session.install(*standard_deps, *constraints) + session.run( + "pip", + "install", + "--no-compile", # To ensure no byte recompliation which is usually super slow + "-q", + "--disable-pip-version-check", # Avoid the slow version check + ".", + "-r", + "requirements.txt", + ) diff --git a/requirements.txt b/requirements.txt index 9e16179..23f0b94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ google-cloud-spanner==3.49.1 -langchain-core==0.3.9 +langchain-core==0.3.31 langchain-community==0.3.1 pydantic==2.9.1 diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index f99205c..fb80763 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -67,11 +67,13 @@ class TableColumn: column_name (str): The name of the column. type (str): The type of the column. is_null (bool): Indicates whether the column allows null values. + vector_length Optional(int): for ANN, mandatory and must be >=1 for the embedding column. """ name: str type: str is_null: bool = True + vector_length: int = None def __post_init__(self): # Check if column_name is None after initialization @@ -81,6 +83,9 @@ def __post_init__(self): if self.type is None: raise ValueError("type is mandatory and cannot be None.") + if (self.vector_length is not None) and (self.vector_length <= 0): + raise ValueError("vector_length must be >=1") + @dataclass class SecondaryIndex: @@ -104,7 +109,7 @@ class VectorSearchIndex: """ index_name: str - columns: list[str] + columns: list[str] # Each column passed in must have ARRAY type. num_leaves: int num_branches: int tree_depth: int @@ -351,7 +356,6 @@ def init_vector_store_table( primary_key: Optional[str] = None, vector_size: Optional[int] = None, secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, - kind: AlgoKind = AlgoKind.KNN, ) -> bool: """ Initialize the vector store new table in Google Cloud Spanner. @@ -365,8 +369,7 @@ def init_vector_store_table( - content_column (str): The name of the content column. Defaults to CONTENT_COLUMN_NAME. - embedding_column (str): The name of the embedding column. Defaults to EMBEDDING_COLUMN_NAME. - metadata_columns (Optional[List[Tuple]]): List of tuples containing metadata column information. Defaults to None. - - vector_size (Optional[int]): The size of the vector. Defaults to None. - - kind (AlgoKind): Defines whether to use k-Nearest Neighbors or Approximate Nearest Neighbors. Defaults to kNN. + - vector_size (Optional[int]): The size of the vector for KNN. Defaults to None. """ client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE) @@ -391,9 +394,9 @@ def init_vector_store_table( metadata_columns, primary_key, secondary_indexes, - kind=kind, ) + print("ddl", "\n".join(ddl)) operation = database.update_ddl(ddl) print("Waiting for operation to complete...") @@ -411,8 +414,6 @@ def _generate_sql( column_configs, primary_key, secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None, - kind: Optional[AlgoKind] = AlgoKind.KNN, - limit=None, ): """ Generate SQL for creating the vector store table. @@ -429,6 +430,16 @@ def _generate_sql( - str: The generated SQL. """ + # 1. If any of the columns is a VectorSearchIndex + embedding_config = list( + filter(lambda x: x.name == embedding_column, column_configs) + ) + print("column_configs", column_configs, "\nembedding_config", embedding_config) + if embedding_column and len(embedding_config) > 0: + config = embedding_config[0] + if config.vector_length is None or config.vector_length <= 0: + raise ValueError("vector_length is mandatory and must be >=1") + ddl_statements = [ SpannerVectorStore._generate_create_table_sql( table_name, @@ -441,14 +452,26 @@ def _generate_sql( ) ] - if kind == AlgoKind.ANN: - ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN( - table_name, dialect, secondary_indexes - ) - else: - ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN( - table_name, embedding_column, dialect, secondary_indexes + ann_indices = list( + filter( + lambda index: isinstance(index, VectorSearchIndex), secondary_indexes ) + ) + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_ANN( + table_name, + dialect, + secondary_indexes=list(ann_indices), + ) + + knn_indices = filter( + lambda index: isinstance(index, SecondaryIndex), secondary_indexes + ) + ddl_statements += SpannerVectorStore._generate_secondary_indices_ddl_KNN( + table_name, + embedding_column, + dialect, + secondary_indexes=list(knn_indices), + ) return ddl_statements @@ -462,7 +485,7 @@ def _generate_create_table_sql( primary_key, dialect=DatabaseDialect.GOOGLE_STANDARD_SQL, ): - create_table_statement = f"CREATE TABLE {table_name} (\n" + create_table_statement = f"CREATE TABLE IF NOT EXISTS {table_name} (\n" if not isinstance(id_column, TableColumn): if dialect == DatabaseDialect.POSTGRESQL: @@ -485,7 +508,7 @@ def _generate_create_table_sql( ) else: embedding_column = TableColumn( - embedding_column, "ARRAY", is_null=True + embedding_column, "ARRAY", is_null=True ) configs = [id_column, content_column, embedding_column] @@ -503,6 +526,9 @@ def _generate_create_table_sql( # Append column name and data type column_sql = f" {column_config.name} {column_config.type}" + if column_config.vector_length and column_config.vector_length >= 1: + column_sql += f"(vector_length=>{column_config.vector_length})" + # Add nullable constraint if specified if not column_config.is_null: column_sql += " NOT NULL" @@ -522,6 +548,7 @@ def _generate_create_table_sql( + ")" ) + # print(create_table_statement) return create_table_statement @staticmethod @@ -559,6 +586,9 @@ def _generate_secondary_indices_ddl_ANN( f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?" ) + if not secondary_indexes: + return [] + secondary_index_ddl_statements = [] for secondary_index in secondary_indexes: @@ -858,6 +888,13 @@ def _insert_data(self, records, columns_to_insert): values=records, ) + def add_ann_rows( + self, data: List[Tuple], id_column_index: int, columns=Dict[str, str] + ) -> List[str]: + self._insert_data(data, columns) + ids = list(map(lambda row: row[id_column_index], data)) + return ids + def add_documents( self, documents: List[Document], @@ -1079,6 +1116,9 @@ def _query_ANN( return sql + def _get_rows_by_similarity_search_ann(): + pass + def _get_rows_by_similarity_search_knn( self, embedding: List[float], diff --git a/tests/integration/test_spanner_vector_store.py b/tests/integration/test_spanner_vector_store.py index 7a16fc4..f00bc62 100644 --- a/tests/integration/test_spanner_vector_store.py +++ b/tests/integration/test_spanner_vector_store.py @@ -22,19 +22,21 @@ import pytest from langchain_google_spanner.vector_store import ( # type: ignore + AlgoKind, DistanceStrategy, QueryParameters, SpannerVectorStore, TableColumn, - AlgoKind, + VectorSearchIndex, ) project_id = os.environ["PROJECT_ID"] instance_id = os.environ["INSTANCE_ID"] google_database = os.environ["GOOGLE_DATABASE"] -pg_database = os.environ["PG_DATABASE"] +pg_database = os.environ.get("PG_DATABASE", None) +zone = os.environ.get("GOOGLE_DATABASE_ZONE", "us-west2") table_name = "test_table" + str(uuid.uuid4()).replace("-", "_") -table_name_ANN = table_name + "_ANN" +table_name_ANN = "products" OPERATION_TIMEOUT_SECONDS = 240 @@ -394,53 +396,230 @@ def test_spanner_vector_search_data4(self, setup_database): class TestSpannerVectorStoreGoogleSQL_ANN: @pytest.fixture(scope="class") def setup_database(self, client): + """ + CREATE TABLE products ( + categoryId INT64 NOT NULL, + productId INT64 NOT NULL, + productName STRING(MAX) NOT NULL, + productDescription STRING(MAX) NOT NULL, + productDescriptionEmbedding ARRAY(vector_length=>728), + createTime TIMESTAMP NOT NULL OPTIONS ( + allow_commit_timestamp = true + ), + inventoryCount INT64 NOT NULL, + priceInCents INT64, + ) PRIMARY KEY(categoryId, productId); + """ + distance_strategy = DistanceStrategy.COSINE SpannerVectorStore.init_vector_store_table( instance_id=instance_id, database_id=google_database, table_name=table_name_ANN, - id_column="row_id", + id_column=TableColumn("productId", type="INT64"), metadata_columns=[ - TableColumn(name="metadata", type="JSON", is_null=True), - TableColumn(name="title", type="STRING(MAX)", is_null=False), + TableColumn(name="categoryId", type="INT64", is_null=False), + TableColumn(name="productName", type="STRING(MAX)", is_null=False), + TableColumn( + name="productDescription", type="STRING(MAX)", is_null=False + ), + TableColumn( + name="productDescriptionEmbedding", + type="ARRAY", + vector_length=758, + is_null=True, + ), + TableColumn(name="inventoryCount", type="INT64", is_null=False), + TableColumn(name="priceInCents", type="INT64", is_null=True), + ], + secondary_indexes=[ + VectorSearchIndex( + index_name="ProductDescriptionEmbeddingIndex", + columns=["productDescriptionEmbedding"], + nullable_column=True, + num_branches=1000, + tree_depth=3, + index_type=distance_strategy, + num_leaves=100000, + ), ], - kind=AlgoKind.ANN, ) - loader = HNLoader("https://news.ycombinator.com/item?id=34817881") + raw_data = [ + ( + 1, + 1, + "Cymbal Helios Helmet", + "Safety meets style with the Cymbal children's bike helmet. Its lightweight design, superior ventilation, and adjustable fit ensure comfort and protection on every ride. Stay bright and keep your child safe under the sun with Cymbal Helios!", + 100, + 10999, + ), + ( + 1, + 2, + "Cymbal Sprout", + "Let their cycling journey begin with the Cymbal Sprout, the ideal balance bike for beginning riders ages 2-4 years. Its lightweight frame, low seat height, and puncture-proof tires promote stability and confidence as little ones learn to balance and steer. Watch them sprout into cycling enthusiasts with Cymbal Sprout!", + 10, + 13999, + ), + ( + 1, + 3, + "Cymbal Spark Jr.", + "Light, vibrant, and ready for adventure, the Spark Jr. is the perfect first bike for young riders (ages 5-8). Its sturdy frame, easy-to-use brakes, and puncture-resistant tires inspire confidence and endless playtime. Let the spark of cycling ignite with Cymbal!", + 34, + 13900, + ), + ( + 1, + 4, + "Cymbal Summit", + "Conquering trails is a breeze with the Summit mountain bike. Its lightweight aluminum frame, responsive suspension, and powerful disc brakes provide exceptional control and comfort for experienced bikers navigating rocky climbs or shredding downhill. Reach new heights with Cymbal Summit!", + 0, + 79999, + ), + ( + 1, + 5, + "Cymbal Breeze", + "Cruise in style and embrace effortless pedaling with the Breeze electric bike. Its whisper-quiet motor and long-lasting battery let you conquer hills and distances with ease. Enjoy scenic rides, commutes, or errands with a boost of confidence from Cymbal Breeze!", + 72, + 129999, + ), + ( + 1, + 6, + "Cymbal Trailblazer Backpack", + "Carry all your essentials in style with the Trailblazer backpack. Its water-resistant material, multiple compartments, and comfortable straps keep your gear organized and accessible, allowing you to focus on the adventure. Blaze new trails with Cymbal Trailblazer!", + 24, + 7999, + ), + ( + 1, + 7, + "Cymbal Phoenix Lights", + "See and be seen with the Phoenix bike lights. Powerful LEDs and multiple light modes ensure superior visibility, enhancing your safety and enjoyment during day or night rides. Light up your journey with Cymbal Phoenix!", + 87, + 3999, + ), + ( + 1, + 8, + "Cymbal Windstar Pump", + "Flat tires are no match for the Windstar pump. Its compact design, lightweight construction, and high-pressure capacity make inflating tires quick and effortless. Get back on the road in no time with Cymbal Windstar!", + 36, + 24999, + ), + ( + 1, + 9, + "Cymbal Odyssey Multi-Tool", + "Be prepared for anything with the Odyssey multi-tool. This handy gadget features essential tools like screwdrivers, hex wrenches, and tire levers, keeping you ready for minor repairs and adjustments on the go. Conquer your journey with Cymbal Odyssey!", + 52, + 999, + ), + ( + 1, + 10, + "Cymbal Nomad Water Bottle", + "Stay hydrated on every ride with the Nomad water bottle. Its sleek design, BPA-free construction, and secure lock lid make it the perfect companion for staying refreshed and motivated throughout your adventures. Hydrate and explore with Cymbal Nomad!", + 42, + 1299, + ), + ] - embeddings = FakeEmbeddings(size=3) + columns = [ + "categoryId", + "productId", + "productName", + "productDescription", + "createTime", + "inventoryCount", + "priceInCents", + ] - yield loader, embeddings + model_ddl_statements = [ + f""" + CREATE MODEL EmbeddingsModel INPUT( + content STRING(MAX), + ) OUTPUT( + embeddings STRUCT, values ARRAY>, + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/{zone}/publishers/google/models/text-embedding-004' + ) + """, + f""" + CREATE MODEL LLMModel INPUT( + prompt STRING(MAX), + ) OUTPUT( + content STRING(MAX), + ) REMOTE OPTIONS ( + endpoint = '//aiplatform.googleapis.com/projects/{project_id}/locations/{zone}/publishers/google/models/gemini-pro', + default_batch_size = 1 + ) + """, + """ + UPDATE products p1 + SET productDescriptionEmbedding = + ( + SELECT embeddings.values from ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT productDescription as content FROM products p2 where p2.productId=p1.productId) + ) + ) + WHERE categoryId=1; + """, + ] + database = client.instance(instance_id).database(google_database) + + def create_models(): + operation = database.update_ddl(model_ddl_statements) + return operation.result(OPERATION_TIMEOUT_SECONDS) + + def get_embeddings(self): + sql = """SELECT embeddings.values FROM ML.PREDICT( + MODEL EmbeddingsModel, + (SELECT "I'd like to buy a starter bike for my 3 year old child" as content) + )""" + + with database.snapshot() as snapshot: + res = snapshot.execute_sql(sql) + return list(res) + + yield raw_data, columns, create_models, get_embeddings print("\nPerforming GSQL cleanup after each ANN test...") - database = client.instance(instance_id).database(google_database) - operation = database.update_ddl([f"DROP TABLE IF EXISTS {table_name_ANN}"]) + operation = database.update_ddl( + [ + f"DROP TABLE IF EXISTS {table_name_ANN}", + "DROP MODEL IF EXISTS EmbeddingsModel", + "DROP MODEL IF EXISTS LLMModel", + "DROP Index IF EXISTS ProductDescriptionEmbeddingIndex", + ] + ) operation.result(OPERATION_TIMEOUT_SECONDS) # Code to perform teardown after each test goes here print("\nGSQL Cleanup complete.") - def test_add_data1(self, setup_database): - loader, embeddings = setup_database + def test_ann_add_data1(self, setup_database): + raw_data, columns, create_models, get_embeddings = setup_database + + # Retrieve embeddings using ML_PREDICT. + embeddings = get_embeddings() + print("embeddings", embeddings) db = SpannerVectorStore( instance_id=instance_id, database_id=google_database, table_name=table_name_ANN, - id_column="row_id", + id_column="categoryId", ignore_metadata_columns=[], embedding_service=embeddings, metadata_json_column="metadata", - kind=AlgoKind.ANN, ) - docs = loader.load() - ids = [str(uuid.uuid4()) for _ in range(len(docs))] - ids_row_inserted = db.add_documents(documents=docs, ids=ids) - assert ids == ids_row_inserted - - def test_add_data2(self, setup_database): + def test_ann_add_data2(self, setup_database): loader, embeddings = setup_database db = SpannerVectorStore( @@ -451,7 +630,6 @@ def test_add_data2(self, setup_database): ignore_metadata_columns=[], embedding_service=embeddings, metadata_json_column="metadata", - kind=AlgoKind.ANN, ) texts = [ @@ -482,7 +660,6 @@ def test_delete_data(self, setup_database): ignore_metadata_columns=[], embedding_service=embeddings, metadata_json_column="metadata", - kind=AlgoKind.ANN, ) docs = loader.load() @@ -499,7 +676,6 @@ def test_search_data1(self, setup_database): # ignore_metadata_columns=[], # embedding_service=embeddings, # metadata_json_column="metadata", - # kind=AlgoKind.ANN, # ) # docs = db.similarity_search( # "Testing the langchain integration with spanner", k=2 @@ -518,7 +694,6 @@ def test_search_data2(self, setup_database): # ignore_metadata_columns=[], # embedding_service=embeddings, # metadata_json_column="metadata", - # kind=AlgoKind.ANN, # ) # embeds = embeddings.embed_query( # "Testing the langchain integration with spanner" @@ -542,7 +717,6 @@ def test_search_data3(self, setup_database): # distance_strategy=DistanceStrategy.COSINE, # max_staleness=datetime.timedelta(seconds=15), # ), - # kind=AlgoKind.ANN, # ) # # docs = db.similarity_search( @@ -566,7 +740,6 @@ def test_search_data4(self, setup_database): # distance_strategy=DistanceStrategy.COSINE, # max_staleness=datetime.timedelta(seconds=15), # ), - # kind=AlgoKind.ANN, # ) # docs = db.max_marginal_relevance_search( # "Testing the langchain integration with spanner", k=3 diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 31a0b03..0cbe555 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -74,7 +74,7 @@ def test_distance_function_raises_exception_if_unknown(self): self.sem.getDistanceFunction(strategy) -class TestSpannerVectorStore_KNN(unittest.TestCase): +class TestSpannerVectorStore(unittest.TestCase): def test_generate_create_table_sql(self): got = SpannerVectorStore._generate_create_table_sql( "users", From 3789da5149babd149994ae307729dfcad962a06b Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 27 Jan 2025 13:06:00 +0200 Subject: [PATCH 12/13] Make VectorSearchIndex inherit from SecondaryIndex --- src/langchain_google_spanner/vector_store.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/langchain_google_spanner/vector_store.py b/src/langchain_google_spanner/vector_store.py index fb80763..10d9837 100644 --- a/src/langchain_google_spanner/vector_store.py +++ b/src/langchain_google_spanner/vector_store.py @@ -103,13 +103,11 @@ def __post_init__(self): @dataclass -class VectorSearchIndex: +class VectorSearchIndex(SecondaryIndex): """ The index for use with Approximate Nearest Neighbor (ANN) vector search. """ - index_name: str - columns: list[str] # Each column passed in must have ARRAY type. num_leaves: int num_branches: int tree_depth: int From 5a0be1b339a493bfca7a74e2d07875512827a964 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 27 Jan 2025 20:26:37 +0200 Subject: [PATCH 13/13] Update copyright header year for new file as requested by headercheck --- tests/unit/test_vectore_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_vectore_store.py b/tests/unit/test_vectore_store.py index 0cbe555..4c06c74 100644 --- a/tests/unit/test_vectore_store.py +++ b/tests/unit/test_vectore_store.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.