Fix/snowflake issues #2691

Merged · 9 commits · Dec 20, 2024

Changes from all commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Deprecate vanilla `DataType`
 - Remove `_Encodable` from project
 - Connect to Snowflake using the incluster oauth token
+- Add postprocess in apibase model.
 
 #### New Features & Functionality
 
@@ -37,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fix the issue where the trigger fails in custom components.
 - Fix serialization between vector-search client and vector-search backend with `to_dict`
 - Fix the bug in the update diff check that replaces uuids
+- Fix snowflake vector search issues.
 
 ## [0.4.0](https://github.com/superduper-io/superduper/compare/0.4.0...0.3.0]) (2024-Nov-02)
2 changes: 1 addition & 1 deletion plugins/ibis/superduper_ibis/__init__.py
@@ -1,6 +1,6 @@
 from .data_backend import IbisDataBackend as DataBackend
 from .query import IbisQuery
 
-__version__ = "0.4.5"
+__version__ = "0.4.6"
 
 __all__ = ["IbisQuery", "DataBackend"]
7 changes: 5 additions & 2 deletions plugins/ibis/superduper_ibis/data_backend.py
@@ -96,10 +96,13 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
         self.overwrite = False
         self._setup(conn)
 
+        self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}
+
         if uri.startswith('snowflake://'):
             self.bytes_encoding = 'base64'
-            self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}
+            self.datatype_presets = {
+                'vector': 'superduper.components.datatype.NativeVector'
+            }
 
     def _setup(self, conn):
         self.dialect = getattr(conn, "name", "base")
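The new Snowflake preset is a dotted import path rather than a class object. Elsewhere in this PR (see `Vector.datatype_impl` in `superduper/components/datatype.py` below), such a path is resolved by splitting off the class name and importing the module. A minimal sketch of that lookup, using only what the diff shows (`resolve_preset` is an illustrative name, not a function in the codebase):

    from importlib import import_module

    def resolve_preset(path: str):
        # 'superduper.components.datatype.NativeVector' -> module path + class name
        module = '.'.join(path.split('.')[:-1])
        cls = path.split('.')[-1]
        # Import the module and fetch the datatype class from it
        return getattr(import_module(module), cls)

    # e.g. resolve_preset('superduper.components.datatype.NativeVector')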
6 changes: 3 additions & 3 deletions plugins/ibis/superduper_ibis/query.py
@@ -80,9 +80,9 @@ def _model_update_impl(
         for output, source_id in zip(outputs, ids):
             d = {
                 "_source": str(source_id),
-                f"{CFG.output_prefix}{predict_id}": output.x
-                if isinstance(output, _Encodable)
-                else output,
+                f"{CFG.output_prefix}{predict_id}": (
+                    output.x if isinstance(output, _Encodable) else output
+                ),
                 "id": str(uuid.uuid4()),
             }
             documents.append(Document(d))
11 changes: 7 additions & 4 deletions plugins/ibis/superduper_ibis/utils.py
@@ -1,8 +1,5 @@
 from ibis.expr.datatypes import dtype
-from superduper.components.datatype import (
-    BaseDataType,
-    File,
-)
+from superduper.components.datatype import BaseDataType, File, Vector
 from superduper.components.schema import ID, FieldType, Schema
 
 SPECIAL_ENCODABLES_FIELDS = {
@@ -39,6 +36,12 @@ def convert_schema_to_fields(schema: Schema):
             if schema.db.databackend.bytes_encoding == 'base64'
             else 'bytes'
         )
+        elif isinstance(v, Vector):
+            fields[k] = dtype('json')
+
+        elif v.encodable == 'native':
+            fields[k] = dtype(v.dtype)
+
         else:
             fields[k] = dtype('str')
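Taken together, the new branches give `Vector` fields a JSON column and let natively-encodable datatypes keep their declared dtype, instead of falling through to `str`. A condensed sketch of the mapping (`field_for` is a hypothetical helper, not code from the plugin):

    from ibis.expr.datatypes import dtype

    from superduper.components.datatype import Vector

    def field_for(v):
        if isinstance(v, Vector):
            return dtype('json')   # vectors now land in a JSON column
        elif getattr(v, 'encodable', None) == 'native':
            return dtype(v.dtype)  # native datatypes keep their own dtype
        else:
            return dtype('str')    # everything else is serialized to a string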
2 changes: 1 addition & 1 deletion plugins/openai/superduper_openai/__init__.py
@@ -1,5 +1,5 @@
 from .model import OpenAIChatCompletion, OpenAIEmbedding
 
-__version__ = "0.4.1"
+__version__ = "0.4.2"
 
 __all__ = 'OpenAIChatCompletion', 'OpenAIEmbedding'
11 changes: 9 additions & 2 deletions plugins/openai/superduper_openai/model.py
@@ -123,14 +123,21 @@ def predict(self, X: str):
         e = self.syncClient.embeddings.create(
             input=X, model=self.model, **self.predict_kwargs
         )
-        return numpy.array(e.data[0].embedding).astype('float32')
+
+        out = numpy.array(e.data[0].embedding).astype('float32')
+        if self.postprocess is not None:
+            out = self.postprocess(out)
+        return out
 
     @retry
     def _predict_a_batch(self, texts: t.List[t.Dict]):
         out = self.syncClient.embeddings.create(
             input=texts, model=self.model, **self.predict_kwargs
         )
-        return [numpy.array(r.embedding).astype('float32') for r in out.data]
+        out = [numpy.array(r.embedding).astype('float32') for r in out.data]
+        if self.postprocess is not None:
+            out = list(map(self.postprocess, out))
+        return out
 
 
 class OpenAIChatCompletion(_OpenAI):
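Since both `predict` and `_predict_a_batch` now route their output through `self.postprocess`, a caller can transform embeddings at prediction time. A usage sketch — the `identifier` argument and the normalisation step are illustrative, not taken from this diff:

    import numpy

    from superduper_openai import OpenAIEmbedding

    def l2_normalise(x: numpy.ndarray) -> numpy.ndarray:
        # Scale the embedding to unit length; leave a zero vector untouched
        norm = numpy.linalg.norm(x)
        return x / norm if norm else x

    model = OpenAIEmbedding(
        identifier='my-embedding',
        model='text-embedding-ada-002',
        postprocess=l2_normalise,
    )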
7 changes: 5 additions & 2 deletions plugins/snowflake/plugin_test/test_vector_search.py
@@ -3,7 +3,7 @@
 
 import pytest
 from superduper import CFG, superduper
-from superduper.components.vector_index import vector
+from superduper.components.datatype import Vector
 
 from superduper_snowflake.vector_search import SnowflakeVectorSearcher
 
@@ -21,8 +21,8 @@
 @pytest.mark.skipif(DO_SKIP, reason='Only snowflake deployments relevant.')
 def test_basic_snowflake_search():
     CFG.vector_search_engine = 'snowflake'
+    CFG.force_apply = True
+
     db = superduper()
-    d1 = vector(shape=[300])
+
+    d1 = Vector(shape=[300])
     build_vector_index(db, n=10, list_embeddings=True, vector_datatype=d1, measure='l2')
 
     vector_index = "vector_index"
2 changes: 1 addition & 1 deletion plugins/snowflake/superduper_snowflake/__init__.py
@@ -1,6 +1,6 @@
 from .vector_search import SnowflakeVectorSearcher as VectorSearcher
 
-__version__ = "0.4.0"
+__version__ = "0.4.1"
 
 __all__ = [
     "VectorSearcher",
9 changes: 6 additions & 3 deletions plugins/snowflake/superduper_snowflake/vector_search.py
@@ -3,7 +3,7 @@
 
 from snowflake.snowpark import Session
 from superduper import CFG
-from superduper.vector_search.base import BaseVectorSearcher, VectorItem
+from superduper.backends.base.vector_search import BaseVectorSearcher, VectorItem
 
 if t.TYPE_CHECKING:
     from superduper.components.vector_index import VectorIndex
@@ -46,8 +46,8 @@ def __init__(
 
         self.measure = measure
         self.dimensions = dimensions
-
-        super().__init__(identifier=identifier, dimensions=dimensions, measure=measure)
+        self._cache = {}
+        self._db = None
 
     @classmethod
     def create_session(cls, vector_search_uri):
@@ -162,3 +162,6 @@ def find_nearest_from_array(self, h, n=100, within_ids=None):
         ids = [row["_source"] for row in result_list]
         scores = [-row["distance".upper()] for row in result_list]
         return ids, scores
+
+    def initialize(self):
+        """Initialize vector search."""
5 changes: 2 additions & 3 deletions superduper/base/datalayer.py
@@ -281,12 +281,11 @@ def _insert(
         if auto_schema and self.cfg.auto_schema:
             schema = self._auto_create_table(insert.table, insert.documents).schema
 
-        timeout = 5
+        timeout = 60
 
         import time
 
         start = time.time()
-
         exists = False
         while time.time() - start < timeout:
             try:
@@ -296,7 +295,7 @@
                 exists = True
             except AssertionError as e:
                 logging.warn(str(e))
-                time.sleep(0.25)
+                time.sleep(1)
                 continue
             break

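The change gives the auto-created table up to 60 seconds to become visible, polling once per second instead of every 0.25 s — presumably because table creation on Snowflake can lag well past the old 5-second window. The polling pattern in isolation (a generic sketch, not the datalayer code):

    import time

    def wait_until(predicate, timeout: float = 60.0, interval: float = 1.0) -> bool:
        # Poll `predicate` until it returns True or `timeout` seconds elapse
        start = time.time()
        while time.time() - start < timeout:
            if predicate():
                return True
            time.sleep(interval)
        return False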
11 changes: 9 additions & 2 deletions superduper/components/datatype.py
@@ -90,13 +90,15 @@ class NativeVector(BaseVector):
     """Datatype for encoding vectors which are supported as list by databackend."""
 
     encodable: t.ClassVar[str] = 'native'
+    dtype: str = 'float'
 
     def encode_data(self, item):
         if isinstance(item, numpy.ndarray):
             item = item.tolist()
         return item
 
     def decode_data(self, item):
+        # TODO:
         return numpy.array(item).astype(self.dtype)
@@ -110,7 +112,7 @@ class Vector(BaseVector):
 
     def __post_init__(self, db):
         self.identifier = f'vector[{self.shape[0]}]'
-        return super().__post_init__(db)
+        super().__post_init__(db)
 
     @property
     def encodable(self):
@@ -122,6 +124,7 @@ def datatype_impl(self):
             type_: str = CFG.datatype_presets.vector
         else:
             type_: str = self.db.databackend.datatype_presets['vector']
+
         module = '.'.join(type_.split('.')[:-1])
         cls = type_.split('.')[-1]
         datatype = getattr(import_module(module), cls)
@@ -137,9 +140,13 @@ def decode_data(self, item):
 
 
 class JSON(BaseDataType):
-    """Datatype for encoding vectors which are supported natively by databackend."""
+    """Datatype for encoding vectors which are supported natively by databackend.
+
+    :param dtype: Datatype of encoded arrays.
+    """
 
     encodable: t.ClassVar[str] = 'native'
+    dtype: str = 'str'
 
     def __post_init__(self, db):
         return super().__post_init__(db)
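A round trip through the new `NativeVector` defaults, sketching the `encode_data`/`decode_data` pair above (assumes `numpy`, as imported in the module):

    import numpy

    vec = numpy.array([0.1, 0.2, 0.3])

    # encode_data: ndarray -> plain list, storable natively by the databackend
    encoded = vec.tolist()

    # decode_data: list -> ndarray cast to the declared dtype (default 'float')
    decoded = numpy.array(encoded).astype('float')

    assert numpy.allclose(vec, decoded)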
3 changes: 2 additions & 1 deletion superduper/components/model.py
@@ -1110,10 +1110,12 @@ class APIBaseModel(Model):
 
     :param model: The Model to use, e.g. ``'text-embedding-ada-002'``
     :param max_batch_size: Maximum batch size.
+    :param postprocess: Postprocess function to use on the output of the API request
     """
 
     model: t.Optional[str] = None
     max_batch_size: int = 8
+    postprocess: t.Optional[t.Callable] = None
 
     def __post_init__(self, db, example):
         super().__post_init__(db, example)
@@ -1149,7 +1151,6 @@ class APIModel(APIBaseModel):
     """
 
     url: str
-    postprocess: t.Optional[t.Callable] = None
 
     @property
     def inputs(self):
3 changes: 3 additions & 0 deletions superduper/rest/base.py
@@ -205,6 +205,9 @@ def _add_templates(self, db):
 
         import os
 
+        if t is None:
+            continue
+
         if os.path.exists(t):
             from superduper import Template
 
24 changes: 19 additions & 5 deletions superduper/rest/build.py
@@ -148,19 +148,33 @@ def db_upload(raw: bytes = File(...), db: 'Datalayer' = DatalayerDependency()):
         return {"component": component, "artifacts": blob_objects}
 
     def _process_db_apply(db, component, id: str | None = None):
+        def _apply():
+            nonlocal component
+            component = Document.decode(component, db=db).unpack()
+            db.apply(component, force=True)
+
         if id:
             log_file = f"/tmp/{id}.log"
            with redirect_stdout_to_file(log_file):
-                db.apply(component, force=True)
+                try:
+                    _apply()
+                except Exception as e:
+                    logging.error(f'Exception during application apply :: {e}')
+                    raise
         else:
-            db.apply(component, force=True)
+            try:
+                _apply()
+
+            except Exception as e:
+                logging.error(f'Exception during application apply :: {e}')
+                raise
 
     @app.add('/describe_tables')
     def describe_tables(db: 'Datalayer' = DatalayerDependency()):
         return db.databackend.list_tables_or_collections()
 
     @app.add('/db/apply', method='post')
-    async def db_apply(
+    def db_apply(
         info: t.Dict,
         background_tasks: BackgroundTasks,
         id: str | None = 'test',
@@ -170,8 +184,8 @@
         info['_variables']['output_prefix'] = CFG.output_prefix
         info['_variables']['databackend'] = db.databackend.backend_name
 
-        component = Document.decode(info, db=db).unpack()
-        background_tasks.add_task(_process_db_apply, db, component, id)
+        # info = Document.decode(info, db=db).unpack()
+        background_tasks.add_task(_process_db_apply, db, info, id)
         return {'status': 'ok'}
 
     import subprocess
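With the decode moved inside `_apply`, the endpoint returns immediately and both decoding and application errors surface in the background task's log rather than in the HTTP response. The FastAPI shape of this, reduced to a sketch with placeholder work (`process_apply` is hypothetical):

    from fastapi import BackgroundTasks, FastAPI

    app = FastAPI()

    def process_apply(payload: dict) -> None:
        # Decode and apply happen here, off the request path;
        # failures are logged by the task body, not returned to the client
        ...

    @app.post('/db/apply')
    def db_apply(info: dict, background_tasks: BackgroundTasks):
        background_tasks.add_task(process_apply, info)
        return {'status': 'ok'}

Dropping `async` from `db_apply` also lets FastAPI run the handler in its threadpool, so any residual blocking work in the request path no longer stalls the event loop.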
Binary files (9) not shown.