From 1b7e33b059f0c3cebc3427139dd83244dddf3d6e Mon Sep 17 00:00:00 2001 From: TheDude Date: Sat, 21 Dec 2024 19:26:01 +0530 Subject: [PATCH] [PLUGINS] Bump Version [snowflake | ibis] --- plugins/ibis/superduper_ibis/__init__.py | 2 +- .../superduper_snowflake/__init__.py | 2 +- .../superduper_snowflake/vector_search.py | 51 ++++++++++++++----- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/plugins/ibis/superduper_ibis/__init__.py b/plugins/ibis/superduper_ibis/__init__.py index fb1f30ab0..92a43d292 100644 --- a/plugins/ibis/superduper_ibis/__init__.py +++ b/plugins/ibis/superduper_ibis/__init__.py @@ -1,6 +1,6 @@ from .data_backend import IbisDataBackend as DataBackend from .query import IbisQuery -__version__ = "0.4.6" +__version__ = "0.4.7" __all__ = ["IbisQuery", "DataBackend"] diff --git a/plugins/snowflake/superduper_snowflake/__init__.py b/plugins/snowflake/superduper_snowflake/__init__.py index 5a440233b..6f0622979 100644 --- a/plugins/snowflake/superduper_snowflake/__init__.py +++ b/plugins/snowflake/superduper_snowflake/__init__.py @@ -1,6 +1,6 @@ from .vector_search import SnowflakeVectorSearcher as VectorSearcher -__version__ = "0.4.3" +__version__ = "0.4.4" __all__ = [ "VectorSearcher", diff --git a/plugins/snowflake/superduper_snowflake/vector_search.py b/plugins/snowflake/superduper_snowflake/vector_search.py index d40fb9e78..12b3d9ff6 100644 --- a/plugins/snowflake/superduper_snowflake/vector_search.py +++ b/plugins/snowflake/superduper_snowflake/vector_search.py @@ -1,3 +1,4 @@ +import os import re import typing as t @@ -55,24 +56,46 @@ def create_session(cls, vector_search_uri): :param vector_search_uri: Connection URI. """ - pattern = r"snowflake://(?P[^:]+):(?P[^@]+)@(?P[^/]+)/(?P[^/]+)/(?P[^/]+)" - match = re.match(pattern, vector_search_uri) + if vector_search_uri == 'snowflake://': + host = os.environ['SNOWFLAKE_HOST'] + port = int(os.environ['SNOWFLAKE_PORT']) + account = os.environ['SNOWFLAKE_ACCOUNT'] + token = open('/snowflake/session/token').read() + warehouse = os.environ['SNOWFLAKE_WAREHOUSE'] + database = os.environ['SNOWFLAKE_DATABASE'] + schema = os.environ['SUPERDUPER_DATA_SCHEMA'] - if match: connection_parameters = { - "user": match.group("user"), - "password": match.group("password"), - "account": match.group("account"), - "database": match.group("database"), - "schema": match.group("schema"), - # TODO: check warehouse - "warehouse": "base", + "token": token, + "account": account, + "database": database, + "schema": schema, + "warehouse": warehouse, + "authenticator": "oauth", + "port": port, + "host": host, } - session = Session.builder.configs(connection_parameters).create() - return session - else: - raise ValueError(f"URI `{vector_search_uri}` is invalid!") + pattern = r"snowflake://(?P[^:]+):(?P[^@]+)@(?P[^/]+)/(?P[^/]+)/(?P[^/]+)" + match = re.match(pattern, vector_search_uri) + schema = match.group("schema") + database = match.group("database") + if match: + connection_parameters = { + "user": match.group("user"), + "password": match.group("password"), + "account": match.group("account"), + "database": match.group("database"), + "schema": match.group("schema"), + # TODO: check warehouse + "warehouse": "base", + } + + else: + raise ValueError(f"URI `{vector_search_uri}` is invalid!") + + session = Session.builder.configs(connection_parameters).create() + return session def __len__(self): pass