Skip to content

Commit

Permalink
Add linting
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Nov 21, 2024
1 parent 0630d44 commit ac895f6
Show file tree
Hide file tree
Showing 23 changed files with 62 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add schema to `Template`
- Low-code form builder for frontend
- Add snowflake vector search engine
- Add a meta-datatype `Vector` to handle different databackend requirements

#### Bug Fixes

Expand Down
3 changes: 1 addition & 2 deletions plugins/ibis/plugin_test/test_databackend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from test.utils.database import databackend as db_utils

import pytest
from superduper import CFG
from test.utils.database import databackend as db_utils

from superduper_ibis.data_backend import IbisDataBackend

Expand Down
3 changes: 1 addition & 2 deletions plugins/ibis/plugin_test/test_query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from test.utils.setup.fake_data import add_listeners, add_models, add_random_data

import numpy as np
import pytest
from superduper.base.document import Document
from superduper.components.schema import Schema
from superduper.components.table import Table
from test.utils.setup.fake_data import add_listeners, add_models, add_random_data


def test_serialize_table():
Expand Down
4 changes: 1 addition & 3 deletions plugins/ibis/superduper_ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def __init__(self, uri: str, flavour: t.Optional[str] = None):
self.overwrite = False
self._setup(conn)

self.datatype_presets = {
'vector': 'superduper.ext.numpy.Array'
}
self.datatype_presets = {'vector': 'superduper.ext.numpy.Array'}

def _setup(self, conn):
self.dialect = getattr(conn, "name", "base")
Expand Down
3 changes: 1 addition & 2 deletions plugins/mongodb/plugin_test/test_artifact_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from test.utils.database import artifact_store as artifact_store_utils

import pytest
from superduper import CFG
from test.utils.database import artifact_store as artifact_store_utils

from superduper_mongodb.data_backend import MongoDataBackend

Expand Down
3 changes: 1 addition & 2 deletions plugins/mongodb/plugin_test/test_databackend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from test.utils.database import databackend as db_utils

import pytest
from superduper import CFG
from test.utils.database import databackend as db_utils

from superduper_mongodb.data_backend import MongoDataBackend

Expand Down
3 changes: 1 addition & 2 deletions plugins/mongodb/plugin_test/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from test.utils.database import metadata as metadata_utils

import pytest
from superduper import CFG
from test.utils.database import metadata as metadata_utils

from superduper_mongodb.metadata import MongoMetaDataStore

Expand Down
8 changes: 4 additions & 4 deletions plugins/mongodb/plugin_test/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import random

import numpy as np
import pytest
from superduper.base.document import Document
from test.utils.setup.fake_data import (
add_listeners,
add_models,
add_random_data,
add_vector_index,
)

import numpy as np
import pytest
from superduper.base.document import Document


def get_new_data(n=10, update=False):
data = []
Expand Down
4 changes: 2 additions & 2 deletions plugins/mongodb/superduper_mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,10 +671,10 @@ def _get_schema(self):
continue
fields[key] = output_table.schema.fields[key]

from superduper.components.datatype import DataType
from superduper.components.datatype import BaseDataType
from superduper.components.schema import Schema

fields = {k: v for k, v in fields.items() if isinstance(v, DataType)}
fields = {k: v for k, v in fields.items() if isinstance(v, BaseDataType)}

return Schema(f"_tmp:{self.table}", fields=fields)

Expand Down
1 change: 0 additions & 1 deletion superduper/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,6 @@ def _prepare_documents(self):
kwargs = self.parts[0][2]
schema = kwargs.pop('schema', None)


if schema is None:
try:
table = self.db.load('table', self.table)
Expand Down
1 change: 1 addition & 0 deletions superduper/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class DataTypePresets(BaseConfig):
:param vector: BaseDataType to encode vectors.
"""

vector: str | None = None


Expand Down
1 change: 0 additions & 1 deletion superduper/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from superduper.components.datatype import (
BaseDataType,
Blob,
DataType,
Encodable,
FileItem,
Native,
Expand Down
35 changes: 19 additions & 16 deletions superduper/components/datatype.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import base64
import dataclasses as dc
from functools import cached_property
import hashlib
from importlib import import_module
import inspect
import io
import json
import numpy
import os
import pickle
import re
import typing as t
from abc import ABC, abstractmethod
from abc import abstractmethod
from functools import cached_property
from importlib import import_module

import dill
import numpy

from superduper import CFG
from superduper.backends.base.artifacts import (
_construct_file_id_from_uri,
)
from superduper.base.config import BytesEncoding
from superduper.base.leaf import Leaf, import_item
from superduper.base.leaf import Leaf
from superduper.components.component import Component, ensure_initialized
from superduper.misc.annotations import component
from superduper.misc.hash import hash_path
Expand Down Expand Up @@ -171,7 +171,7 @@ def check(data: t.Any) -> bool:

@staticmethod
@abstractmethod
def create(data: t.Any) -> "DataType":
def create(data: t.Any) -> "BaseDataType":
"""Create a DataType for the data.
:param data: The data to create the DataType for
Expand All @@ -191,12 +191,12 @@ class BaseDataType(Component):

@abstractmethod
def encode_data(self, item, info: t.Optional[t.Dict] = None):
"""Decode the item as `bytes`
"""Decode the item as `bytes`.
:param item: The item to decode.
:param info: The optional information dictionary.
"""

@abstractmethod
def decode_data(self, item, info: t.Optional[t.Dict] = None):
"""Decode the item from bytes.
Expand All @@ -207,19 +207,21 @@ def decode_data(self, item, info: t.Optional[t.Dict] = None):

def encode_data_with_identifier(self, item, info: t.Optional[t.Dict] = None):
b = self.encode_data(item=item, info=info)
if type(b) == bytes:
if isinstance(b, bytes):
return b, hashlib.sha1(b).hexdigest()
else:
return b, hashlib.sha1(str(b).encode()).hexdigest()


class NativeVector(BaseDataType):
def encode_data(self, item, info = None):
"""Datatype for encoding vectors which are supported natively by databackend."""

def encode_data(self, item, info=None):
if isinstance(item, numpy.ndarray):
item = item.tolist()
return item
def decode_data(self, item, info = None):

def decode_data(self, item, info=None):
return numpy.array(item)


Expand All @@ -242,8 +244,8 @@ class DataType(BaseDataType):

encoder: t.Optional[t.Callable] = None # not necessary if encodable is file
decoder: t.Optional[t.Callable] = None
info: t.Optional[t.Dict] = None # TODO deprecate
directory: t.Optional[str] = None # TODO needed?
info: t.Optional[t.Dict] = None # TODO deprecate
directory: t.Optional[str] = None # TODO needed?
encodable: str = 'encodable'
bytes_encoding: t.Optional[str] = CFG.bytes_encoding
intermediate_type: t.Optional[str] = IntermediateType.BYTES
Expand All @@ -261,6 +263,7 @@ def __post_init__(self, db, artifacts):
self.encodable_cls = _ENCODABLES[self.encodable]
else:
import importlib

self.encodable_cls = importlib.import_module(
'.'.join(self.encodable.split('.')[:-1])
).__dict__[self.encodable.split('.')[-1]]
Expand Down Expand Up @@ -848,6 +851,7 @@ def get_serializer(


class Vector(BaseDataType):
"""Vector meta-datatype for encoding vectors ready for search."""

identifier: str = ''

Expand All @@ -872,6 +876,5 @@ def datatype_impl(self):
def encode_data(self, item, info: t.Optional[t.Dict] = None):
return self.datatype_impl.encode_data(item=item, info=info)


def decode_data(self, item, info: t.Optional[t.Dict] = None):
return self.datatype_impl.decode_data(item=item, info=info)
return self.datatype_impl.decode_data(item=item, info=info)
8 changes: 4 additions & 4 deletions superduper/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,18 @@ def __post_init__(self, db, artifacts):
@cached_property
def encoded_types(self):
"""List of fields of type DataType."""
return [k for k, v in self.fields.items() if isinstance(v, DataType)]
return [k for k, v in self.fields.items() if isinstance(v, BaseDataType)]

@cached_property
def trivial(self):
"""Determine if the schema contains only trivial fields."""
return not any([isinstance(v, DataType) for v in self.fields.values()])
return not any([isinstance(v, BaseDataType) for v in self.fields.values()])

@property
def encoders(self):
"""An iterable to list DataType fields."""
for v in self.fields.values():
if isinstance(v, DataType):
if isinstance(v, BaseDataType):
yield v

@property
Expand Down Expand Up @@ -137,7 +137,7 @@ def __call__(self, data: dict[str, t.Any]) -> dict[str, t.Any]:

encoded_data = {}
for k, v in data.items():
if k in self.fields and isinstance(self.fields[k], DataType):
if k in self.fields and isinstance(self.fields[k], BaseDataType):
field_encoder = self.fields[k]
assert callable(field_encoder)
encoded_data.update({k: field_encoder(v)})
Expand Down
1 change: 1 addition & 0 deletions superduper/components/vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def dimensions(self) -> int:

try:
assert dt.shape is not None, msg
assert isinstance(dt.shape, tuple)
return dt.shape[-1]
except IndexError as e:
raise Exception(
Expand Down
2 changes: 1 addition & 1 deletion superduper/ext/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing as t

from .encoder import array, Array
from .encoder import Array, array

requirements: t.List = []

Expand Down
5 changes: 3 additions & 2 deletions superduper/ext/numpy/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ class Array(BaseDataType):
:param dtype: numpy native datatype
"""

dtype: str = 'float64'

def encode_data(self, item, info = None):
def encode_data(self, item, info=None):
encoder = EncodeArray(self.dtype)
return encoder(item)

def decode_data(self, item, info = None):
def decode_data(self, item, info=None):
shape = self.shape
if isinstance(shape, int):
shape = (self.shape,)
Expand Down
12 changes: 6 additions & 6 deletions superduper/misc/auto_schema.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import importlib
import numpy as np
import typing as t

import numpy as np

from superduper import CFG, logging
from superduper.base.exceptions import UnsupportedDatatype
from superduper.components.datatype import (
BaseDataType,
DataType,
DataTypeFactory,
_BaseEncodable,
Vector,
_BaseEncodable,
get_serializer,
json_serializer,
)
Expand Down Expand Up @@ -59,6 +60,7 @@ def infer_datatype(data: t.Any) -> t.Optional[t.Union[DataType, type]]:

try:
from bson import ObjectId

if isinstance(data, ObjectId):
return datatype
except ImportError:
Expand Down Expand Up @@ -143,7 +145,7 @@ def check(data: t.Any) -> bool:
return isinstance(data, np.ndarray) and len(data.shape) == 1

@staticmethod
def create(data: t.Any) -> DataType | FieldType:
def create(data: t.Any) -> BaseDataType | FieldType:
"""Create a JSON datatype.
:param data: The data object
Expand All @@ -167,7 +169,7 @@ def check(data: t.Any) -> bool:
return False

@staticmethod
def create(data: t.Any) -> DataType | FieldType:
def create(data: t.Any) -> BaseDataType | FieldType:
"""Create a JSON datatype.
:param data: The data object
Expand All @@ -184,5 +186,3 @@ def create(data: t.Any) -> DataType | FieldType:

FACTORIES = DataTypeFactory.__subclasses__()
FACTORIES = sorted(FACTORIES, key=lambda x: 0 if x.__module__ == __name__ else 1)


7 changes: 4 additions & 3 deletions test/unittest/component/datatype/test_vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy

from superduper.components.datatype import Vector


def test_auto_detect_vector(db):
db.cfg.auto_schema = True
Expand All @@ -14,4 +12,7 @@ def test_auto_detect_vector(db):

impl = schema.fields['x'].datatype_impl

assert impl.__module__ + '.' + impl.__class__.__name__ == db.databackend.datatype_presets['vector']
assert (
impl.__module__ + '.' + impl.__class__.__name__
== db.databackend.datatype_presets['vector']
)
6 changes: 5 additions & 1 deletion test/unittest/component/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,15 @@ def _test(X, docs):
# X = {'a': 'x', 'b': 'y'}
# _test(X, docs)


import numpy


@pytest.fixture
def object_model():
return ObjectModel('test', object=lambda x: numpy.array(x) + 1, signature='singleton')
return ObjectModel(
'test', object=lambda x: numpy.array(x) + 1, signature='singleton'
)


def test_object_model_predict(object_model):
Expand Down
Loading

0 comments on commit ac895f6

Please sign in to comment.