Skip to content

Commit

Permalink
Fix SQL tests by adding dtype parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Nov 21, 2024
1 parent befa6a0 commit 1f9334b
Show file tree
Hide file tree
Showing 12 changed files with 53 additions and 23 deletions.
3 changes: 2 additions & 1 deletion plugins/ibis/plugin_test/test_databackend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from test.utils.database import databackend as db_utils

import pytest
from superduper import CFG
from test.utils.database import databackend as db_utils

from superduper_ibis.data_backend import IbisDataBackend

Expand Down
3 changes: 2 additions & 1 deletion plugins/ibis/plugin_test/test_query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from test.utils.setup.fake_data import add_listeners, add_models, add_random_data

import numpy as np
import pytest
from superduper.base.document import Document
from superduper.components.schema import Schema
from superduper.components.table import Table
from test.utils.setup.fake_data import add_listeners, add_models, add_random_data


def test_serialize_table():
Expand Down
4 changes: 2 additions & 2 deletions plugins/ibis/superduper_ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
self.overwrite = False
self._setup(conn)

self.datatype_presets = {'vector': 'superduper.ext.numpy.Array'}
self.datatype_presets = {'vector': 'superduper.ext.numpy.encoder.Array'}

def _setup(self, conn):
self.dialect = getattr(conn, "name", "base")
Expand Down Expand Up @@ -294,4 +294,4 @@ def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None
"""
from superduper.misc.auto_schema import infer_schema

return infer_schema(data, identifier=identifier, ibis=True)
return infer_schema(data, identifier=identifier)
8 changes: 4 additions & 4 deletions plugins/ibis/superduper_ibis/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ibis.expr.datatypes import dtype
from superduper import CFG
from superduper.components.datatype import (
Artifact,
DataType,
BaseDataType,
File,
LazyArtifact,
LazyFile,
Expand All @@ -23,7 +24,6 @@ def _convert_field_type_to_ibis_type(field_type: FieldType):
ibis_type = "String"
else:
ibis_type = field_type.identifier

return dtype(ibis_type)


Expand All @@ -39,12 +39,12 @@ def convert_schema_to_fields(schema: Schema):
for k, v in schema.fields.items():
if isinstance(v, FieldType):
fields[k] = _convert_field_type_to_ibis_type(v)
elif not isinstance(v, DataType):
elif not isinstance(v, BaseDataType):
fields[k] = v.identifier
else:
if v.encodable_cls in SPECIAL_ENCODABLES_FIELDS:
fields[k] = dtype(SPECIAL_ENCODABLES_FIELDS[v.encodable_cls])
else:
fields[k] = v.bytes_encoding
fields[k] = CFG.bytes_encoding

return fields
3 changes: 2 additions & 1 deletion plugins/mongodb/plugin_test/test_artifact_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from test.utils.database import artifact_store as artifact_store_utils

import pytest
from superduper import CFG
from test.utils.database import artifact_store as artifact_store_utils

from superduper_mongodb.data_backend import MongoDataBackend

Expand Down
3 changes: 2 additions & 1 deletion plugins/mongodb/plugin_test/test_databackend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from test.utils.database import databackend as db_utils

import pytest
from superduper import CFG
from test.utils.database import databackend as db_utils

from superduper_mongodb.data_backend import MongoDataBackend

Expand Down
3 changes: 2 additions & 1 deletion plugins/mongodb/plugin_test/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from test.utils.database import metadata as metadata_utils

import pytest
from superduper import CFG
from test.utils.database import metadata as metadata_utils

from superduper_mongodb.metadata import MongoMetaDataStore

Expand Down
8 changes: 4 additions & 4 deletions plugins/mongodb/plugin_test/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import random

import numpy as np
import pytest
from superduper.base.document import Document
from test.utils.setup.fake_data import (
add_listeners,
add_models,
add_random_data,
add_vector_index,
)

import numpy as np
import pytest
from superduper.base.document import Document


def get_new_data(n=10, update=False):
data = []
Expand Down
26 changes: 21 additions & 5 deletions superduper/components/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,24 @@ def encode_data_with_identifier(self, item, info: t.Optional[t.Dict] = None):


class NativeVector(BaseDataType):
"""Datatype for encoding vectors which are supported natively by databackend."""
"""Datatype for encoding vectors which are supported natively by databackend.
:param dtype: Datatype of array to encode.
"""

dtype: str = 'float64'

def __post_init__(self, db, artifacts):
self.encodable_cls = Native
return super().__post_init__(db, artifacts)

def encode_data(self, item, info=None):
if isinstance(item, numpy.ndarray):
item = item.tolist()
return item

def decode_data(self, item, info=None):
return numpy.array(item)
return numpy.array(item).astype(self.dtype)


class DataType(BaseDataType):
Expand Down Expand Up @@ -851,15 +860,22 @@ def get_serializer(


class Vector(BaseDataType):
"""Vector meta-datatype for encoding vectors ready for search."""
"""Vector meta-datatype for encoding vectors ready for search.
:param dtype: Datatype of encoded arrays.
"""

identifier: str = ''
dtype: str = 'float64'

def __post_init__(self, db, artifacts):
self.identifier = f'vector[{self.shape[0]}]'
self.encodable_cls = Native
return super().__post_init__(db, artifacts)

@property
def encodable_cls(self):
return self.datatype_impl.encodable_cls

@cached_property
def datatype_impl(self):
if isinstance(CFG.datatype_presets.vector, str):
Expand All @@ -870,7 +886,7 @@ def datatype_impl(self):
cls = type_.split('.')[-1]
datatype = getattr(import_module(module), cls)
if inspect.isclass(datatype):
datatype = datatype('tmp')
datatype = datatype('tmp', dtype=self.dtype)
return datatype

def encode_data(self, item, info: t.Optional[t.Dict] = None):
Expand Down
11 changes: 10 additions & 1 deletion superduper/ext/numpy/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import numpy

from superduper.components.datatype import BaseDataType, DataType, DataTypeFactory
from superduper.components.datatype import (
BaseDataType,
DataType,
DataTypeFactory,
Encodable,
)
from superduper.ext.utils import str_shape
from superduper.misc.annotations import component

Expand Down Expand Up @@ -55,6 +60,10 @@ class Array(BaseDataType):

dtype: str = 'float64'

def __post_init__(self, db, artifacts):
self.encodable_cls = Encodable
return super().__post_init__(db, artifacts)

def encode_data(self, item, info=None):
encoder = EncodeArray(self.dtype)
return encoder(item)
Expand Down
2 changes: 1 addition & 1 deletion superduper/misc/auto_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def create(data: t.Any) -> BaseDataType | FieldType:
:param data: The data object
"""
return Vector(shape=(len(data),))
return Vector(shape=(len(data),), dtype=str(data.dtype))


class JsonDataTypeFactory(DataTypeFactory):
Expand Down
2 changes: 1 addition & 1 deletion test/utils/usecase/graph_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def build_graph_listener(db: "Datalayer"):

db["documents"].insert(data).execute()

data = db['documents'].find().tolist()
data = db['documents'].select().tolist()

assert isinstance(data[0]['z'], np.ndarray)

Expand Down

0 comments on commit 1f9334b

Please sign in to comment.