diff --git a/plugins/ibis/plugin_test/test_databackend.py b/plugins/ibis/plugin_test/test_databackend.py index e6029db39..87551cdae 100644 --- a/plugins/ibis/plugin_test/test_databackend.py +++ b/plugins/ibis/plugin_test/test_databackend.py @@ -1,6 +1,7 @@ +from test.utils.database import databackend as db_utils + import pytest from superduper import CFG -from test.utils.database import databackend as db_utils from superduper_ibis.data_backend import IbisDataBackend diff --git a/plugins/ibis/plugin_test/test_query.py b/plugins/ibis/plugin_test/test_query.py index d2bc4d999..7a8c1bc92 100644 --- a/plugins/ibis/plugin_test/test_query.py +++ b/plugins/ibis/plugin_test/test_query.py @@ -1,9 +1,10 @@ +from test.utils.setup.fake_data import add_listeners, add_models, add_random_data + import numpy as np import pytest from superduper.base.document import Document from superduper.components.schema import Schema from superduper.components.table import Table -from test.utils.setup.fake_data import add_listeners, add_models, add_random_data def test_serialize_table(): diff --git a/plugins/ibis/superduper_ibis/data_backend.py b/plugins/ibis/superduper_ibis/data_backend.py index 5c5bf6d91..5dcb49eb6 100644 --- a/plugins/ibis/superduper_ibis/data_backend.py +++ b/plugins/ibis/superduper_ibis/data_backend.py @@ -69,7 +69,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None): self.overwrite = False self._setup(conn) - self.datatype_presets = {'vector': 'superduper.ext.numpy.Array'} + self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'} def _setup(self, conn): self.dialect = getattr(conn, "name", "base") @@ -294,4 +294,4 @@ def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None """ from superduper.misc.auto_schema import infer_schema - return infer_schema(data, identifier=identifier, ibis=True) + return infer_schema(data, identifier=identifier) diff --git a/plugins/ibis/superduper_ibis/utils.py b/plugins/ibis/superduper_ibis/utils.py index 37930eb60..e33efd263 100644 --- a/plugins/ibis/superduper_ibis/utils.py +++ b/plugins/ibis/superduper_ibis/utils.py @@ -1,7 +1,8 @@ from ibis.expr.datatypes import dtype +from superduper import CFG from superduper.components.datatype import ( Artifact, - DataType, + BaseDataType, File, LazyArtifact, LazyFile, @@ -23,7 +24,6 @@ def _convert_field_type_to_ibis_type(field_type: FieldType): ibis_type = "String" else: ibis_type = field_type.identifier - return dtype(ibis_type) @@ -39,12 +39,12 @@ def convert_schema_to_fields(schema: Schema): for k, v in schema.fields.items(): if isinstance(v, FieldType): fields[k] = _convert_field_type_to_ibis_type(v) - elif not isinstance(v, DataType): + elif not isinstance(v, BaseDataType): fields[k] = v.identifier else: if v.encodable_cls in SPECIAL_ENCODABLES_FIELDS: fields[k] = dtype(SPECIAL_ENCODABLES_FIELDS[v.encodable_cls]) else: - fields[k] = v.bytes_encoding + fields[k] = CFG.bytes_encoding return fields diff --git a/plugins/mongodb/plugin_test/test_artifact_store.py b/plugins/mongodb/plugin_test/test_artifact_store.py index 67d7a0dd8..2960c51ae 100644 --- a/plugins/mongodb/plugin_test/test_artifact_store.py +++ b/plugins/mongodb/plugin_test/test_artifact_store.py @@ -1,6 +1,7 @@ +from test.utils.database import artifact_store as artifact_store_utils + import pytest from superduper import CFG -from test.utils.database import artifact_store as artifact_store_utils from superduper_mongodb.data_backend import MongoDataBackend diff --git a/plugins/mongodb/plugin_test/test_databackend.py b/plugins/mongodb/plugin_test/test_databackend.py index f7f073ecf..15dbcdf87 100644 --- a/plugins/mongodb/plugin_test/test_databackend.py +++ b/plugins/mongodb/plugin_test/test_databackend.py @@ -1,6 +1,7 @@ +from test.utils.database import databackend as db_utils + import pytest from superduper import CFG -from test.utils.database import databackend as db_utils from superduper_mongodb.data_backend import MongoDataBackend diff --git a/plugins/mongodb/plugin_test/test_metadata.py b/plugins/mongodb/plugin_test/test_metadata.py index 9db969775..26adc4e9a 100644 --- a/plugins/mongodb/plugin_test/test_metadata.py +++ b/plugins/mongodb/plugin_test/test_metadata.py @@ -1,6 +1,7 @@ +from test.utils.database import metadata as metadata_utils + import pytest from superduper import CFG -from test.utils.database import metadata as metadata_utils from superduper_mongodb.metadata import MongoMetaDataStore diff --git a/plugins/mongodb/plugin_test/test_queries.py b/plugins/mongodb/plugin_test/test_queries.py index 1e074ac23..258222913 100644 --- a/plugins/mongodb/plugin_test/test_queries.py +++ b/plugins/mongodb/plugin_test/test_queries.py @@ -1,8 +1,4 @@ import random - -import numpy as np -import pytest -from superduper.base.document import Document from test.utils.setup.fake_data import ( add_listeners, add_models, @@ -10,6 +6,10 @@ add_vector_index, ) +import numpy as np +import pytest +from superduper.base.document import Document + def get_new_data(n=10, update=False): data = [] diff --git a/superduper/components/datatype.py b/superduper/components/datatype.py index 41507f342..09a39ff57 100644 --- a/superduper/components/datatype.py +++ b/superduper/components/datatype.py @@ -214,7 +214,16 @@ def encode_data_with_identifier(self, item, info: t.Optional[t.Dict] = None): class NativeVector(BaseDataType): - """Datatype for encoding vectors which are supported natively by databackend.""" + """Datatype for encoding vectors which are supported natively by databackend. + + :param dtype: Datatype of array to encode. + """ + + dtype: str = 'float64' + + def __post_init__(self, db, artifacts): + self.encodable_cls = Native + return super().__post_init__(db, artifacts) def encode_data(self, item, info=None): if isinstance(item, numpy.ndarray): @@ -222,7 +231,7 @@ def encode_data(self, item, info=None): return item def decode_data(self, item, info=None): - return numpy.array(item) + return numpy.array(item).astype(self.dtype) class DataType(BaseDataType): @@ -851,15 +860,22 @@ def get_serializer( class Vector(BaseDataType): - """Vector meta-datatype for encoding vectors ready for search.""" + """Vector meta-datatype for encoding vectors ready for search. + + :param dtype: Datatype of encoded arrays. + """ identifier: str = '' + dtype: str = 'float64' def __post_init__(self, db, artifacts): self.identifier = f'vector[{self.shape[0]}]' - self.encodable_cls = Native return super().__post_init__(db, artifacts) + @property + def encodable_cls(self): + return self.datatype_impl.encodable_cls + @cached_property def datatype_impl(self): if isinstance(CFG.datatype_presets.vector, str): @@ -870,7 +886,7 @@ def datatype_impl(self): cls = type_.split('.')[-1] datatype = getattr(import_module(module), cls) if inspect.isclass(datatype): - datatype = datatype('tmp') + datatype = datatype('tmp', dtype=self.dtype) return datatype def encode_data(self, item, info: t.Optional[t.Dict] = None): diff --git a/superduper/ext/numpy/encoder.py b/superduper/ext/numpy/encoder.py index d73fd7102..cc1a62bd9 100644 --- a/superduper/ext/numpy/encoder.py +++ b/superduper/ext/numpy/encoder.py @@ -2,7 +2,12 @@ import numpy -from superduper.components.datatype import BaseDataType, DataType, DataTypeFactory +from superduper.components.datatype import ( + BaseDataType, + DataType, + DataTypeFactory, + Encodable, +) from superduper.ext.utils import str_shape from superduper.misc.annotations import component @@ -55,6 +60,10 @@ class Array(BaseDataType): dtype: str = 'float64' + def __post_init__(self, db, artifacts): + self.encodable_cls = Encodable + return super().__post_init__(db, artifacts) + def encode_data(self, item, info=None): encoder = EncodeArray(self.dtype) return encoder(item) diff --git a/superduper/misc/auto_schema.py b/superduper/misc/auto_schema.py index 55c4b1148..831a7b4e5 100644 --- a/superduper/misc/auto_schema.py +++ b/superduper/misc/auto_schema.py @@ -150,7 +150,7 @@ def create(data: t.Any) -> BaseDataType | FieldType: :param data: The data object """ - return Vector(shape=(len(data),)) + return Vector(shape=(len(data),), dtype=str(data.dtype)) class JsonDataTypeFactory(DataTypeFactory): diff --git a/test/utils/usecase/graph_listener.py b/test/utils/usecase/graph_listener.py index 2bd3006c0..5c358a20e 100644 --- a/test/utils/usecase/graph_listener.py +++ b/test/utils/usecase/graph_listener.py @@ -38,7 +38,7 @@ def build_graph_listener(db: "Datalayer"): db["documents"].insert(data).execute() - data = db['documents'].find().tolist() + data = db['documents'].select().tolist() assert isinstance(data[0]['z'], np.ndarray)