Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement Approximate Nearest Neighbor support for DDL (CREATE TABLE, CREATE VECTOR INDEX) #124

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
8 changes: 7 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Use a vector store to store embedded data and perform vector search.

.. code-block:: python

from langchain_google_sapnner import SpannerVectorstore
from langchain_google_spanner import SpannerVectorstore
from langchain.embeddings import VertexAIEmbeddings

embeddings_service = VertexAIEmbeddings(model_name="textembedding-gecko@003")
Expand Down Expand Up @@ -206,3 +206,9 @@ Disclaimer

This is not an officially supported Google product.


Limitations
-----------

* Approximate Nearest Neighbors (ANN) strategies are only supported for the GoogleSQL dialect
* ANN's ``ALTER VECTOR INDEX`` is not supported by `Google Cloud Spanner <https://cloud.google.com/spanner/docs/find-approximate-nearest-neighbors#limitations>`_
98 changes: 93 additions & 5 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,34 @@

import os
import pathlib
import shutil
from pathlib import Path
from typing import Optional
import shutil
from typing import List, Optional

import nox

DEFAULT_PYTHON_VERSION = "3.10"
CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
FLAKE8_VERSION = "flake8==6.1.0"
BLACK_VERSION = "black[jupyter]==23.7.0"
ISORT_VERSION = "isort==5.11.0"
LINT_PATHS = ["src", "tests", "noxfile.py"]


nox.options.sessions = [
"docs",
"blacken",
"docfx",
"docs",
"format",
"lint",
"unit",
]

# Error if a python version is missing
nox.options.error_on_missing_interpreters = True


@nox.session(python="3.10")
@nox.session(python=DEFAULT_PYTHON_VERSION)
def docs(session):
"""Build the docs for this library."""

Expand Down Expand Up @@ -71,7 +80,7 @@ def docs(session):
)


@nox.session(python="3.10")
@nox.session(python=DEFAULT_PYTHON_VERSION)
def docfx(session):
"""Build the docfx yaml files for this library."""

Expand Down Expand Up @@ -115,3 +124,82 @@ def docfx(session):
os.path.join("docs", ""),
os.path.join("docs", "_build", "html", ""),
)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def lint(session):
    """Run linters.

    Checks formatting with ``black --check`` and then runs ``flake8``, both
    over the paths listed in ``LINT_PATHS``.

    Returns a failure if the linters find linting errors or
    sufficiently serious code quality issues.
    """
    session.install(FLAKE8_VERSION, BLACK_VERSION)
    session.run(
        "black",
        "--check",
        *LINT_PATHS,
    )
    # Lint the same paths black checks. The previous hard-coded "google"
    # target is not one of this project's lint paths, and "src" was never
    # linted by flake8 at all.
    session.run("flake8", *LINT_PATHS)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def lint_setup_py(session):
    """Check setup.py metadata, including strict reStructuredText validation."""
    session.install("docutils", "pygments")
    check_command = (
        "python",
        "setup.py",
        "check",
        "--restructuredtext",
        "--strict",
    )
    session.run(*check_command)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def blacken(session):
    """Reformat every path in LINT_PATHS in place with black."""
    session.install(BLACK_VERSION)
    black_command = ["black", *LINT_PATHS]
    session.run(*black_command)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def format(session):
    """Sort imports with isort, then reformat with black, over LINT_PATHS."""
    session.install(BLACK_VERSION, ISORT_VERSION)
    formatter_commands = (
        # Sort imports in strict alphabetical order.
        ("isort", "--fss"),
        # black runs second, after isort has rewritten the imports.
        ("black",),
    )
    for command in formatter_commands:
        session.run(*command, *LINT_PATHS)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def unit(session):
    """Install the unit-test dependencies and run the tests in tests/unit."""
    install_unittest_dependencies(session)
    unit_tests_dir = os.path.join("tests", "unit")
    session.run("py.test", "--quiet", unit_tests_dir)


# Packages that install_unittest_dependencies() always installs into the
# session before the project itself.
UNIT_TEST_STANDARD_DEPENDENCIES = [
    "mock",
    "pytest",
]
# Additional unit-test packages installed alongside the standard ones;
# currently empty.
UNIT_TEST_DEPENDENCIES: List[str] = []


def install_unittest_dependencies(session, *constraints):
    """Install the unit-test packages, then the project and its requirements.

    Args:
        session: The nox session whose environment receives the packages.
        *constraints: Extra arguments passed to ``session.install`` after the
            unit-test package names.
    """
    test_packages = [*UNIT_TEST_STANDARD_DEPENDENCIES, *UNIT_TEST_DEPENDENCIES]
    session.install(*test_packages, *constraints)

    pip_command = [
        "pip",
        "install",
        "--no-compile",  # Skip byte-compilation, which is usually very slow.
        "-q",
        "--disable-pip-version-check",  # Avoid the slow version check.
        ".",
        "-r",
        "requirements.txt",
    ]
    session.run(*pip_command)
1 change: 0 additions & 1 deletion src/langchain_google_spanner/graph_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def _call(
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:

intermediate_steps: List = []

"""Generate gql statement, uses it to look up in db and answer question."""
Expand Down
20 changes: 10 additions & 10 deletions src/langchain_google_spanner/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from __future__ import annotations

from abc import abstractmethod
import re
import string
from abc import abstractmethod
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union

from google.cloud import spanner
Expand Down Expand Up @@ -211,9 +211,9 @@ def from_nodes(name: str, nodes: List[Node]) -> ElementSchema:
for k, v in n.properties.items()
}
)
node.types[ElementSchema.NODE_KEY_COLUMN_NAME] = (
TypeUtility.value_to_param_type(nodes[0].id)
)
node.types[
ElementSchema.NODE_KEY_COLUMN_NAME
] = TypeUtility.value_to_param_type(nodes[0].id)
return node

@staticmethod
Expand Down Expand Up @@ -264,12 +264,12 @@ def from_edges(name: str, edges: List[Relationship]) -> ElementSchema:
for k, v in e.properties.items()
}
)
edge.types[ElementSchema.NODE_KEY_COLUMN_NAME] = (
TypeUtility.value_to_param_type(edges[0].source.id)
)
edge.types[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = (
TypeUtility.value_to_param_type(edges[0].target.id)
)
edge.types[
ElementSchema.NODE_KEY_COLUMN_NAME
] = TypeUtility.value_to_param_type(edges[0].source.id)
edge.types[
ElementSchema.TARGET_NODE_KEY_COLUMN_NAME
] = TypeUtility.value_to_param_type(edges[0].target.id)

edge.source = NodeReference(
edges[0].source.type,
Expand Down
2 changes: 1 addition & 1 deletion src/langchain_google_spanner/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
import datetime
import json
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union

from google.cloud.spanner import Client, KeySet # type: ignore
Expand Down
Loading