Feature/2629/vector type switch #2637

Merged · 5 commits · Nov 21, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add schema to `Template`
- Low-code form builder for frontend
- Add snowflake vector search engine
- Add a meta-datatype `Vector` to handle different databackend requirements

#### Bug Fixes

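The `Vector` meta-datatype added in the changelog entry above selects a concrete vector encoding per data backend rather than fixing one implementation. A minimal usage sketch, assuming the constructor takes `shape` as a sequence (as `self.shape[0]` in the implementation below suggests); the table and field names are purely illustrative:

```python
from superduper.components.datatype import Vector
from superduper.components.schema import Schema

# Hedged sketch: 'docs' and 'embedding' are illustrative names, not from this PR.
# Vector derives its identifier from the shape (e.g. 'vector[384]') and defers
# the actual encoding to whichever implementation the databackend prefers.
schema = Schema(
    'docs',
    fields={'embedding': Vector(shape=(384,), dtype='float32')},
)
```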
4 changes: 3 additions & 1 deletion plugins/ibis/superduper_ibis/data_backend.py
@@ -69,6 +69,8 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
self.overwrite = False
self._setup(conn)

self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}

def _setup(self, conn):
self.dialect = getattr(conn, "name", "base")
self.db_helper = get_db_helper(self.dialect)
@@ -292,4 +294,4 @@ def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None
"""
from superduper.misc.auto_schema import infer_schema

return infer_schema(data, identifier=identifier, ibis=True)
return infer_schema(data, identifier=identifier)
8 changes: 4 additions & 4 deletions plugins/ibis/superduper_ibis/utils.py
@@ -1,7 +1,8 @@
from ibis.expr.datatypes import dtype
from superduper import CFG
from superduper.components.datatype import (
Artifact,
DataType,
BaseDataType,
File,
LazyArtifact,
LazyFile,
@@ -23,7 +24,6 @@ def _convert_field_type_to_ibis_type(field_type: FieldType):
ibis_type = "String"
else:
ibis_type = field_type.identifier

return dtype(ibis_type)


@@ -39,12 +39,12 @@ def convert_schema_to_fields(schema: Schema):
for k, v in schema.fields.items():
if isinstance(v, FieldType):
fields[k] = _convert_field_type_to_ibis_type(v)
elif not isinstance(v, DataType):
elif not isinstance(v, BaseDataType):
fields[k] = v.identifier
else:
if v.encodable_cls in SPECIAL_ENCODABLES_FIELDS:
fields[k] = dtype(SPECIAL_ENCODABLES_FIELDS[v.encodable_cls])
else:
fields[k] = v.bytes_encoding
fields[k] = CFG.bytes_encoding

return fields
4 changes: 4 additions & 0 deletions plugins/mongodb/superduper_mongodb/data_backend.py
@@ -40,6 +40,10 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):

self._db = self.conn[self.name]

self.datatype_presets = {
'vector': 'superduper.components.datatype.NativeVector'
}

def reconnect(self):
"""Reconnect to mongodb store."""
# Reconnect to database.
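Each data backend now advertises its preferred vector encoding through `datatype_presets`: the ibis backend (earlier in this diff) points at `superduper.ext.numpy.encoder.Array`, while MongoDB uses the new `NativeVector`. A hedged sketch of what a hypothetical third-party backend might declare (everything except the `datatype_presets` attribute is illustrative):

```python
class MyDataBackend:  # hypothetical backend, not part of this PR
    def __init__(self, uri: str):
        self.uri = uri
        # Dotted path to any importable BaseDataType; the Vector meta-datatype
        # resolves it lazily and delegates encode_data/decode_data to it.
        self.datatype_presets = {
            'vector': 'superduper.components.datatype.NativeVector'
        }
```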
4 changes: 2 additions & 2 deletions plugins/mongodb/superduper_mongodb/query.py
@@ -671,10 +671,10 @@ def _get_schema(self):
continue
fields[key] = output_table.schema.fields[key]

from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema

fields = {k: v for k, v in fields.items() if isinstance(v, DataType)}
fields = {k: v for k, v in fields.items() if isinstance(v, BaseDataType)}

return Schema(f"_tmp:{self.table}", fields=fields)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
"boto3>=1.16",
"dill>=0.3.6",
"loguru>=0.7.2",
"loki-logger-handler>=0.1.1",
"loki-logger-handler==1.0.0",
"networkx>=2.8.8",
"requests>=2.22", # lower bound from openai and boto3
"tqdm>=4.64.1",
1 change: 1 addition & 0 deletions superduper/backends/base/query.py
@@ -738,6 +738,7 @@ def _prepare_documents(self):
schema = table.schema
except FileNotFoundError:
pass

documents = [
r.encode(schema) if isinstance(r, Document) else r for r in documents
]
15 changes: 15 additions & 0 deletions superduper/base/config.py
@@ -130,6 +130,18 @@ class RestConfig(BaseConfig):
config: t.Optional[str] = None


@dc.dataclass
class DataTypePresets(BaseConfig):
"""Paths of default types of data.

Overrides DataBackend.datatype_presets.

:param vector: BaseDataType to encode vectors.
"""

vector: str | None = None


@dc.dataclass
class Config(BaseConfig):
"""The data class containing all configurable superduper values.
@@ -148,6 +160,7 @@ class Config(BaseConfig):
:param log_level: The severity level of the logs
:param logging_type: The type of logging to use
:param force_apply: Whether to force apply the configuration
:param datatype_presets: Presets to be applied for default types of data
:param bytes_encoding: The encoding of bytes in the data backend
:param auto_schema: Whether to automatically create the schema.
If True, the schema will be created if it does not exist.
@@ -180,6 +193,8 @@ class Config(BaseConfig):

force_apply: bool = False

datatype_presets: DataTypePresets = dc.field(default_factory=DataTypePresets)

bytes_encoding: BytesEncoding = BytesEncoding.BYTES
auto_schema: bool = True
json_native: bool = True
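The config-level preset takes precedence over the backend preset (see `Vector.datatype_impl` later in this diff). A minimal sketch of overriding it, assuming the global `CFG` object can be mutated at runtime; in practice the value would more likely come from the usual superduper configuration mechanism:

```python
from superduper import CFG

# Hedged sketch: force every Vector field to encode via the numpy Array
# datatype regardless of databackend. The dotted path is the same one the
# ibis backend presets in this PR.
CFG.datatype_presets.vector = 'superduper.ext.numpy.encoder.Array'
```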
4 changes: 2 additions & 2 deletions superduper/base/document.py
@@ -14,8 +14,8 @@
from superduper.base.variables import _replace_variables
from superduper.components.component import Component
from superduper.components.datatype import (
BaseDataType,
Blob,
DataType,
Encodable,
FileItem,
Native,
@@ -598,7 +598,7 @@ def _schema_decode(
decoded = {}
for k, value in data.items():
field = schema.fields.get(k)
if not isinstance(field, DataType):
if not isinstance(field, BaseDataType):
decoded[k] = value
continue

109 changes: 102 additions & 7 deletions superduper/components/datatype.py
@@ -9,8 +9,11 @@
import re
import typing as t
from abc import abstractmethod
from functools import cached_property
from importlib import import_module

import dill
import numpy

from superduper import CFG
from superduper.backends.base.artifacts import (
@@ -168,23 +171,77 @@ def check(data: t.Any) -> bool:

@staticmethod
@abstractmethod
def create(data: t.Any) -> "DataType":
def create(data: t.Any) -> "BaseDataType":
"""Create a DataType for the data.

:param data: The data to create the DataType for
"""
raise NotImplementedError


class DataType(Component):
class BaseDataType(Component):
"""Base class for datatype.

:param shape: size of vector
"""

type_id: t.ClassVar[str] = 'datatype'
# TODO this can just be an integer
shape: t.Optional[int] = None

@abstractmethod
def encode_data(self, item, info: t.Optional[t.Dict] = None):
"""Decode the item as `bytes`.

:param item: The item to decode.
:param info: The optional information dictionary.
"""

@abstractmethod
def decode_data(self, item, info: t.Optional[t.Dict] = None):
"""Decode the item from bytes.

:param item: The item to decode.
:param info: The optional information dictionary.
"""

def encode_data_with_identifier(self, item, info: t.Optional[t.Dict] = None):
b = self.encode_data(item=item, info=info)
if isinstance(b, bytes):
return b, hashlib.sha1(b).hexdigest()
else:
return b, hashlib.sha1(str(b).encode()).hexdigest()


class NativeVector(BaseDataType):
"""Datatype for encoding vectors which are supported natively by databackend.

:param dtype: Datatype of array to encode.
"""

dtype: str = 'float64'

def __post_init__(self, db, artifacts):
self.encodable_cls = Native
return super().__post_init__(db, artifacts)

def encode_data(self, item, info=None):
if isinstance(item, numpy.ndarray):
item = item.tolist()
return item

def decode_data(self, item, info=None):
return numpy.array(item).astype(self.dtype)


class DataType(BaseDataType):
"""A data type component that defines how data is encoded and decoded.

:param encoder: A callable that converts an encodable object of this
encoder to bytes.
:param decoder: A callable that converts bytes to an encodable object
of this encoder.
:param info: An optional information dictionary.
:param shape: The shape of the data.
:param directory: The directory to store file types.
:param encodable: The type of encodable object ('encodable',
'lazy_artifact', or 'file').
Expand All @@ -194,12 +251,10 @@ class DataType(Component):
:param media_type: The media type.
"""

type_id: t.ClassVar[str] = 'datatype'
encoder: t.Optional[t.Callable] = None # not necessary if encodable is file
decoder: t.Optional[t.Callable] = None
info: t.Optional[t.Dict] = None
shape: t.Optional[t.Sequence] = None
directory: t.Optional[str] = None
info: t.Optional[t.Dict] = None # TODO deprecate
directory: t.Optional[str] = None # TODO needed?
encodable: str = 'encodable'
bytes_encoding: t.Optional[str] = CFG.bytes_encoding
intermediate_type: t.Optional[str] = IntermediateType.BYTES
@@ -334,6 +389,7 @@ def encode_torch_state_dict(module, info):
return buffer.getvalue()


# TODO migrate to torch plugin
class DecodeTorchStateDict:
"""Torch state dictionary decoder.

@@ -450,6 +506,8 @@ class Blob(Leaf):
bytes: bytes


# TODO this is no longer strictly needed, since we now encode
# directly with `Schema`
class Encodable(_BaseEncodable):
"""Class for encoding non-Python datatypes to the database.

@@ -799,3 +857,40 @@ def get_serializer(
'dill_lazy': dill_lazy,
'file_lazy': file_lazy,
}


class Vector(BaseDataType):
"""Vector meta-datatype for encoding vectors ready for search.

:param dtype: Datatype of encoded arrays.
"""

identifier: str = ''
dtype: str = 'float64'

def __post_init__(self, db, artifacts):
self.identifier = f'vector[{self.shape[0]}]'
return super().__post_init__(db, artifacts)

@property
def encodable_cls(self):
return self.datatype_impl.encodable_cls

@cached_property
def datatype_impl(self):
if isinstance(CFG.datatype_presets.vector, str):
type_: str = CFG.datatype_presets.vector
else:
type_: str = self.db.databackend.datatype_presets['vector']
module = '.'.join(type_.split('.')[:-1])
cls = type_.split('.')[-1]
datatype = getattr(import_module(module), cls)
if inspect.isclass(datatype):
datatype = datatype('tmp', dtype=self.dtype)
return datatype

def encode_data(self, item, info: t.Optional[t.Dict] = None):
return self.datatype_impl.encode_data(item=item, info=info)

def decode_data(self, item, info: t.Optional[t.Dict] = None):
return self.datatype_impl.decode_data(item=item, info=info)
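A rough round-trip sketch of the two datatypes added in this file, assuming `Component` subclasses accept their identifier as the first positional argument:

```python
import numpy
from superduper.components.datatype import NativeVector, Vector

# NativeVector stores vectors in a form the backend supports natively
# (a plain Python list) and restores a numpy array on decode.
native = NativeVector('my_vector', dtype='float32')
stored = native.encode_data(numpy.array([0.1, 0.2, 0.3]))  # -> [0.1, 0.2, 0.3]
restored = native.decode_data(stored)                       # -> ndarray, float32

# Vector only chooses an implementation: its identifier becomes 'vector[3]',
# and encode_data/decode_data are forwarded to whatever
# CFG.datatype_presets.vector or the databackend's datatype_presets point at.
vec = Vector(shape=(3,), dtype='float32')
```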
14 changes: 7 additions & 7 deletions superduper/components/schema.py
@@ -3,7 +3,7 @@

from superduper.base.leaf import Leaf
from superduper.components.component import Component
from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType, DataType
from superduper.misc.reference import parse_reference
from superduper.misc.special_dicts import SuperDuperFlatEncode

@@ -54,7 +54,7 @@ def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)

for k, v in self.fields.items():
if isinstance(v, (DataType, FieldType)):
if isinstance(v, (BaseDataType, FieldType)):
continue

try:
@@ -67,18 +67,18 @@
@cached_property
def encoded_types(self):
"""List of fields of type DataType."""
return [k for k, v in self.fields.items() if isinstance(v, DataType)]
return [k for k, v in self.fields.items() if isinstance(v, BaseDataType)]

@cached_property
def trivial(self):
"""Determine if the schema contains only trivial fields."""
return not any([isinstance(v, DataType) for v in self.fields.values()])
return not any([isinstance(v, BaseDataType) for v in self.fields.values()])

@property
def encoders(self):
"""An iterable to list DataType fields."""
for v in self.fields.values():
if isinstance(v, DataType):
if isinstance(v, BaseDataType):
yield v

@property
@@ -99,7 +99,7 @@ def encode_data(self, out, builds, blobs, files, leaves_to_keep=()):
:param files: Files.
"""
for k, field in self.fields.items():
if not isinstance(field, DataType):
if not isinstance(field, BaseDataType):
continue

if k not in out:
@@ -137,7 +137,7 @@ def __call__(self, data: dict[str, t.Any]) -> dict[str, t.Any]:

encoded_data = {}
for k, v in data.items():
if k in self.fields and isinstance(self.fields[k], DataType):
if k in self.fields and isinstance(self.fields[k], BaseDataType):
field_encoder = self.fields[k]
assert callable(field_encoder)
encoded_data.update({k: field_encoder(v)})
1 change: 1 addition & 0 deletions superduper/components/vector_index.py
@@ -371,6 +371,7 @@ def dimensions(self) -> int:

try:
assert dt.shape is not None, msg
assert isinstance(dt.shape, (tuple, list))
return dt.shape[-1]
except IndexError as e:
raise Exception(
4 changes: 2 additions & 2 deletions superduper/ext/numpy/__init__.py
@@ -1,7 +1,7 @@
import typing as t

from .encoder import array
from .encoder import Array, array

requirements: t.List = []

__all__ = ['array']
__all__ = ['array', 'Array']