Skip to content

Commit

Permalink
Apply some changes from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 21, 2025
1 parent 0a1982a commit 7e5279a
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 73 deletions.
96 changes: 49 additions & 47 deletions src/langchain_google_spanner/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,6 @@ class SecondaryIndex:
index_name: str
columns: list[str]
storing_columns: Optional[list[str]] = None
num_leaves: Optional[int] = None # Only necessary for ANN
nullable_column: Optional[bool] = False # Only necessary for ANN
num_branches: Optional[int] = None # Only necessary for ANN
tree_depth: Optional[int] = None # Only necessary for ANN
index_type: Optional[DistanceStrategy] = None # Only necessary for ANN

def __post_init__(self):
# Check if column_name is None after initialization
Expand All @@ -102,24 +97,39 @@ def __post_init__(self):
raise ValueError("Index Columns can't be None")


@dataclass
class VectorSearchIndex:
    """Settings describing a vector search index on a table.

    Attributes:
        index_name: Name of the index; must not be None.
        columns: Names of the indexed columns; must be a non-empty list.
        num_leaves: Number of leaves configured for the index.
        num_branches: Number of branches configured for the index.
        tree_depth: Depth of the index tree; only 2 and 3 are accepted.
        index_type: Distance strategy the index is built for.
        nullable_column: Whether the indexed column may hold NULLs.
            Defaults to False.

    Raises:
        ValueError: If ``index_name`` is None, ``columns`` is None or empty,
            or ``tree_depth`` is neither 2 nor 3.
    """

    index_name: str
    columns: list[str]
    num_leaves: int
    num_branches: int
    tree_depth: int
    # Quoted so the annotation does not require DistanceStrategy to already
    # be defined at the point where this class body is executed.
    index_type: "DistanceStrategy"
    nullable_column: bool = False

    def __post_init__(self):
        """Validate the field values right after dataclass initialization."""
        if self.index_name is None:
            raise ValueError("index_name must be set")

        # Truthiness also covers columns=None, which len() would reject with
        # a TypeError instead of the intended ValueError.
        if not self.columns:
            raise ValueError("columns must be set")

        if self.tree_depth not in (2, 3):
            raise ValueError("tree_depth must be either 2 or 3")


class DistanceStrategy(Enum):
"""
Enum for distance calculation strategies.
"""

COSINE = 1
EUCLIDEIAN = 2
EUCLIDEAN = 2
DOT_PRODUCT = 3

def __str__(self):
return DISTANCE_STRATEGY_STRING[self]


DISTANCE_STRATEGY_STRING = {
DistanceStrategy.COSINE: "COSINE",
DistanceStrategy.EUCLIDEIAN: "EUCLIDEIAN",
DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT",
}
return self.name


class DialectSemantics(ABC):
Expand All @@ -128,7 +138,7 @@ class DialectSemantics(ABC):
"""

@abstractmethod
def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str:
def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str:
"""
Abstract method to get the distance function based on the provided distance strategy.
Expand All @@ -155,22 +165,18 @@ def getDeleteDocumentsValueParameters(self, columns, values) -> Dict[str, Any]:
)


_GOOGLE_DISTANCE_ALGO_NAMES = {
# Maps between distance strategy enums and the appropriate GoogleSQL distance function name.
GOOGLE_DIALECT_DISTANCE_FUCNTIONS = {
DistanceStrategy.COSINE: "COSINE_DISTANCE",
DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT",
DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN_DISTANCE",
DistanceStrategy.EUCLIDEAN: "EUCLIDEAN_DISTANCE",
}

# Maps between distance strategy and the appropriate ANN search function name.
distance_strategy_to_ANN_function = {
DistanceStrategy.COSINE: "APPROX_COSINE_DISTANCE",
DistanceStrategy.DOT_PRODUCT: "APPROX_DOT_PRODUCT",
DistanceStrategy.EUCLIDEIAN: "APPROX_EUCLIDEAN_DISTANCE",
}

_GOOGLE_ALGO_INDEX_NAME = {
DistanceStrategy.COSINE: "COSINE",
DistanceStrategy.DOT_PRODUCT: "DOT_PRODUCT",
DistanceStrategy.EUCLIDEIAN: "EUCLIDEAN",
DistanceStrategy.EUCLIDEAN: "APPROX_EUCLIDEAN_DISTANCE",
}


Expand All @@ -179,8 +185,8 @@ class GoogleSqlSemnatics(DialectSemantics):
Implementation of dialect semantics for Google SQL.
"""

def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str:
return _GOOGLE_DISTANCE_ALGO_NAMES.get(distance_strategy, "EUCLIDEAN")
def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str:
    """Return the GoogleSQL distance function name for ``distance_strategy``.

    Args:
        distance_strategy: The DistanceStrategy to look up. Defaults to
            DistanceStrategy.EUCLIDEAN.

    Returns:
        The matching entry of GOOGLE_DIALECT_DISTANCE_FUCNTIONS, or the
        Euclidean entry when the strategy is not present in that table.
    """
    # Fall back to the same name the table uses for EUCLIDEAN; the bare
    # "EUCLIDEAN" is not one of the function names in the table.
    return GOOGLE_DIALECT_DISTANCE_FUCNTIONS.get(
        distance_strategy, "EUCLIDEAN_DISTANCE"
    )

def getDeleteDocumentsParameters(self, columns) -> Tuple[str, Any]:
where_clause_condition = " AND ".join(
Expand All @@ -201,10 +207,11 @@ def getIndexDistanceType(self, distance_strategy) -> str:
return value


_PG_DISTANCE_ALGO_NAMES = {
# Maps between DistanceStrategy and the expected PostgreSQL distance equivalent.
PG_DIALECT_DISTANCE_FUNCTIONS = {
DistanceStrategy.COSINE: "spanner.cosine_distance",
DistanceStrategy.DOT_PRODUCT: "spanner.dot_product",
DistanceStrategy.EUCLIDEIAN: "spanner.euclidean_distance",
DistanceStrategy.EUCLIDEAN: "spanner.euclidean_distance",
}


Expand All @@ -213,8 +220,8 @@ class PGSqlSemnatics(DialectSemantics):
Implementation of dialect semantics for PostgreSQL.
"""

def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEIAN) -> str:
name = _PG_DISTANCE_ALGO_NAMES.get(distance_strategy, None)
def getDistanceFunction(self, distance_strategy=DistanceStrategy.EUCLIDEAN) -> str:
name = PG_DIALECT_DISTANCE_FUNCTIONS.get(distance_strategy, None)
if name is None:
raise Exception(
"Unsupported PostgreSQL distance strategy: {}".format(distance_strategy)
Expand Down Expand Up @@ -254,7 +261,7 @@ class QueryParameters:

class NearestNeighborsAlgorithm(Enum):
"""
Enum for nearest neighbors search algorithms.
Enum for k-nearest neighbors search algorithms.
"""

EXACT_NEAREST_NEIGHBOR = 1
Expand All @@ -263,7 +270,7 @@ class NearestNeighborsAlgorithm(Enum):
def __init__(
self,
algorithm=NearestNeighborsAlgorithm.EXACT_NEAREST_NEIGHBOR,
distance_strategy=DistanceStrategy.EUCLIDEIAN,
distance_strategy=DistanceStrategy.EUCLIDEAN,
read_timestamp: Optional[datetime.datetime] = None,
min_read_timestamp: Optional[datetime.datetime] = None,
max_staleness: Optional[datetime.timedelta] = None,
Expand Down Expand Up @@ -303,10 +310,6 @@ def __init__(
self.staleness = {key: value}


DEFAULT_ANN_TREE_DEPTH = 2
ANN_ACCEPTABLE_TREE_DEPTHS = (2, 3)


class AlgoKind(Enum):
KNN = 0
ANN = 1
Expand Down Expand Up @@ -341,8 +344,8 @@ def init_vector_store_table(
metadata_columns: Optional[List[TableColumn]] = None,
primary_key: Optional[str] = None,
vector_size: Optional[int] = None,
secondary_indexes: Optional[List[SecondaryIndex]] = None,
kind: AlgoKind = None,
secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None,
kind: AlgoKind = AlgoKind.KNN,
) -> bool:
"""
Initialize the vector store new table in Google Cloud Spanner.
Expand All @@ -357,6 +360,7 @@ def init_vector_store_table(
- embedding_column (str): The name of the embedding column. Defaults to EMBEDDING_COLUMN_NAME.
- metadata_columns (Optional[List[Tuple]]): List of tuples containing metadata column information. Defaults to None.
- vector_size (Optional[int]): The size of the vector. Defaults to None.
- kind (AlgoKind): Defines whether to use k-Nearest Neighbors or Approximate Nearest Neighbors. Defaults to kNN.
"""

client = client_with_user_agent(client, USER_AGENT_VECTOR_STORE)
Expand Down Expand Up @@ -400,7 +404,7 @@ def _generate_sql(
embedding_column,
column_configs,
primary_key,
secondary_indexes: Optional[List[SecondaryIndex]] = None,
secondary_indexes: Optional[List[SecondaryIndex | VectorSearchIndex]] = None,
kind: Optional[AlgoKind] = AlgoKind.KNN,
limit=None,
):
Expand Down Expand Up @@ -546,23 +550,21 @@ def _generate_secondary_indices_ddl_ANN(
):
if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL:
raise Exception(
f"ANN is only supported for the GoogleSQL dialect not {dialect}"
f"ANN is only supported for the GoogleSQL dialect not {dialect}. File an issue on Github?"
)

secondary_index_ddl_statements = []

for secondary_index in secondary_indexes:
column_name = secondary_index.columns[0]
statement = f"CREATE VECTOR INDEX {secondary_index.index_name}\n\tON {table_name}({column_name})"
if secondary_index.nullable_column:
if getattr(secondary_index, "nullable_column", False):
statement += f"\n\tWHERE {column_name} IS NOT NULL"
options_segments = [f"distance_type='{secondary_index.index_type}'"]
if secondary_index.tree_depth > 0:
if getattr(secondary_index, "tree_depth", 0) > 0:
tree_depth = secondary_index.tree_depth
if tree_depth not in ANN_ACCEPTABLE_TREE_DEPTHS:
raise Exception(
f"tree_depth: {tree_depth} is not in the acceptable values: {ANN_ACCEPTABLE_TREE_DEPTHS}"
)
if tree_depth not in (2, 3):
raise Exception(f"tree_depth: {tree_depth} must be either 2 or 3")
options_segments.append(f"tree_depth={secondary_index.tree_depth}")

if secondary_index.num_branches > 0:
Expand Down Expand Up @@ -761,7 +763,7 @@ def _validate_table_schema(self, column_type_map, types, default_columns):
def _select_relevance_score_fn(self) -> Callable[[float], float]:
if self._query_parameters.distance_strategy == DistanceStrategy.COSINE:
return self._cosine_relevance_score_fn
elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEIAN:
elif self._query_parameters.distance_strategy == DistanceStrategy.EUCLIDEAN:
return self._euclidean_relevance_score_fn
else:
raise Exception(
Expand Down
Loading

0 comments on commit 7e5279a

Please sign in to comment.