Skip to content

Commit

Permalink
Deprecate signature as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Nov 29, 2024
1 parent 14950b4 commit ef827c4
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Deprecate vanilla `DataType`
- Remove `_Encodable` from project
- Deprecate `signature` as parameter and auto-infer from `.predict`

#### New Features & Functionality

Expand Down
7 changes: 2 additions & 5 deletions superduper/components/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ class DocumentInput(Model):

spec: t.Union[str, t.List[str]]
identifier: str = '_input'
signature: Signature = 'singleton'

def __post_init__(self, db, example):
super().__post_init__(db, example)
Expand Down Expand Up @@ -232,12 +231,11 @@ class Graph(Model):
:param edges: Graph edges list.
:param input: Graph root node.
:param outputs: Graph output nodes.
:param signature: Graph signature.
Example:
-------
>> g = Graph(
>> identifier='simple-graph', input=model1, outputs=[model2], signature='*args'
>> identifier='simple-graph', input=model1, outputs=[model2],
>> )
>> g.connect(model1, model2)
>> assert g.predict(1) == [(4, 2)]
Expand All @@ -253,15 +251,14 @@ class Graph(Model):
)
input: Model
outputs: t.List[t.Union[str, Model]] = dc.field(default_factory=list)
signature: Signature = '*args,**kwargs'

def __post_init__(self, db, example):
self.G = nx.DiGraph()
self.nodes = {}
self.version = 0
self._db = None

self.signature = self.input.signature
self._signature = self.input.signature
if isinstance(self.outputs, list):
self.output_identifiers = [
o.identifier if isinstance(o, Model) else o for o in self.outputs
Expand Down
67 changes: 56 additions & 11 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def model(
output_schema: t.Optional[Schema] = None,
num_workers: int = 0,
example: t.Any | None = None,
signature: Signature = '*args,**kwargs',
):
"""Decorator to wrap a function with `ObjectModel`.
Expand All @@ -59,7 +58,6 @@ def model(
:param output_schema: Schema for the model outputs.
:param num_workers: Number of workers to use for parallel processing
:param example: Example to auto-determine the schema/ datatype.
:param signature: Signature for the model.
"""
if item is not None and (inspect.isclass(item) or callable(item)):
if inspect.isclass(item):
Expand All @@ -74,7 +72,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return object_model_factory
Expand All @@ -88,7 +85,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)
else:

Expand All @@ -105,7 +101,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return object_model_factory
Expand All @@ -119,7 +114,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return decorated_function
Expand Down Expand Up @@ -352,13 +346,32 @@ def __new__(mcls, name, bases, dct):
if 'init' in dct:
dct['init'] = init_decorator(dct['init'])
cls = super().__new__(mcls, name, bases, dct)

signature = inspect.signature(cls.predict)
pos = []
kw = []
for k in signature.parameters:
if k == 'self':
continue
if signature.parameters[k].default == inspect._empty:
pos.append(k)
else:
kw.append(k)
if len(pos) == 1 and not kw:
cls._signature = 'singleton'
elif pos and not kw:
cls._signature = '*args'
elif pos and kw:
cls._signature = '*args,**kwargs'
else:
assert not pos and kw
cls._signature = '**kwargs'
return cls


class Model(Component, metaclass=ModelMeta):
"""Base class for components which can predict.
:param signature: Model signature.
:param datatype: DataType instance.
:param output_schema: Output schema (mapping of encoders).
:param model_update_kwargs: The kwargs to use for model update.
Expand All @@ -378,7 +391,6 @@ class Model(Component, metaclass=ModelMeta):

breaks: t.ClassVar[t.Sequence] = ('trainer',)
type_id: t.ClassVar[str] = 'model'
signature: Signature = '*args,**kwargs'
datatype: EncoderArg = None
output_schema: t.Optional[Schema] = None
model_update_kwargs: t.Dict = dc.field(default_factory=dict)
Expand All @@ -400,6 +412,10 @@ def __post_init__(self, db, example):
if not self.identifier:
raise Exception('_Predictor identifier must be non-empty')

@property
def signature(self):
return self._signature

@property
def inputs(self) -> Inputs:
"""Instance of `Inputs` to represent model params."""
Expand Down Expand Up @@ -1039,6 +1055,10 @@ class ImportedModel(Model):
object: Leaf
method: t.Optional[str] = None

def __post_init__(self, db, example):
super().__post_init__(db, example)
self._inferred_signature = None

@staticmethod
def _infer_signature(object):
# find positional and key-word parameters from the object
Expand All @@ -1059,6 +1079,13 @@ def _infer_signature(object):
return '**kwargs'
return '*args,**kwargs'

@property
@ensure_initialized
def signature(self):
if self._inferred_signature is None:
self._inferred_signature = self._infer_signature(self.object)
return self._inferred_signature

@property
def outputs(self):
"""Get an instance of ``IndexableNode`` to index outputs."""
Expand Down Expand Up @@ -1197,6 +1224,7 @@ class QueryModel(Model):
:param preprocess: Preprocess callable
:param postprocess: Postprocess callable
:param select: query used to find data (can include `like`)
:param signature: signature to use
"""

preprocess: t.Optional[t.Callable] = None
Expand Down Expand Up @@ -1258,10 +1286,13 @@ class SequentialModel(Model):
models: t.List[Model]

def __post_init__(self, db, example):
self.signature = self.models[0].signature
self.datatype = self.models[-1].datatype
return super().__post_init__(db, example)

@property
def signature(self):
return self.models[0].signature

@property
def inputs(self) -> Inputs:
"""Instance of `Inputs` to represent model params."""
Expand Down Expand Up @@ -1291,7 +1322,22 @@ def predict(self, *args, **kwargs):
:param args: Positional arguments to predict on.
:param kwargs: Keyword arguments to predict on.
"""
return self.predict_batches([(args, kwargs)])[0]
for i, p in enumerate(self.models):
assert isinstance(p, Model), f'Expected `Model`, got {type(p)}'
if i == 0:
out = p.predict(*args, **kwargs)
else:
if p.signature == 'singleton':
out = p.predict(out)
elif p.signature == '*args':
out = p.predict(*out)
elif p.signature == '**kwargs':
out = p.predict(**out)
else:
msg = 'Model defines a predict with no free parameters'
assert p.signature == '*args,**kwargs', msg
out = p.predict(*out[0], **out[1])
return out

def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
"""Execute on series of data point defined in dataset.
Expand Down Expand Up @@ -1358,7 +1404,6 @@ class RAGModel(Model):
breaks: t.ClassVar[t.Sequence] = ('llm', 'prompt_template')

prompt_template: str
signature: str = 'singleton'
select: Query
key: str
llm: Model
Expand Down
8 changes: 3 additions & 5 deletions superduper/rest/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,9 @@ def db_remove(
@app.add('/db/show_template', method='get')
def db_show_template(
identifier: str,
type_id: str = 'template',
db: 'Datalayer' = DatalayerDependency(),
):
template: Template = db.load(type_id=type_id, identifier=identifier)
template: Template = db.load(type_id='template', identifier=identifier)
return template.form_template

@app.add('/db/edit', method='get')
Expand All @@ -248,10 +247,9 @@ def db_edit(
db: 'Datalayer' = DatalayerDependency(),
):
component = db.load(type_id, identifier)
template = db.load('template', component['build_template'])
template: Template = db.load(type_id=type_id, identifier=identifier)
template: Template = db.load('template', component.build_template)
form = template.form_template
form['_variables'] = component['build_variables']
form['_variables'] = component.build_variables
return form

@app.add('/db/metadata/show_jobs', method='get')
Expand Down
7 changes: 4 additions & 3 deletions test/unittest/base/test_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@ def test_replace(db: Datalayer):
object=lambda x: x + 1,
identifier='m',
datatype=DataType(identifier='base'),
signature='singleton',
)
model.version = 0
db.apply(model)
Expand All @@ -482,7 +481,8 @@ def test_replace(db: Datalayer):
assert db.load('model', 'm').predict(1) == 2

new_model = ObjectModel(
object=lambda x: x + 2, identifier='m', signature='singleton'
object=lambda x: x + 2,
identifier='m',
)
new_model.version = 0
db.replace(new_model)
Expand All @@ -494,7 +494,8 @@ def test_replace(db: Datalayer):

# replace the last version of the model
new_model = ObjectModel(
object=lambda x: x + 3, identifier='m', signature='singleton'
object=lambda x: x + 3,
identifier='m',
)
new_model.version = 0
db.replace(new_model)
Expand Down
11 changes: 7 additions & 4 deletions test/unittest/component/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def model1(db):
def model_object(x):
return x + 1

model = ObjectModel(identifier='m1', object=model_object, signature='singleton')
model = ObjectModel(identifier='m1', object=model_object)
yield model


Expand Down Expand Up @@ -53,13 +53,17 @@ def model_object(x, y):

def test_simple_graph(model1, model2):
g = Graph(
identifier='simple-graph', input=model1, outputs=model2, signature='*args'
identifier='simple-graph',
input=model1,
outputs=model2,
)
g.connect(model1, model2)
assert g.predict(1) == (4, 2)

g = Graph(
identifier='simple-graph', input=model1, outputs=model2, signature='*args'
identifier='simple-graph',
input=model1,
outputs=model2,
)
g.connect(model1, model2)
assert g.predict_batches([1, 2, 3]) == [(4, 2), (5, 3), (6, 4)]
Expand All @@ -70,7 +74,6 @@ def test_graph_output_indexing(model2_multi_dict, model2, model1):
identifier='simple-graph',
input=model1,
outputs=[model2],
signature='**kwargs',
)
g.connect(model1, model2_multi_dict, on=(None, 'x'))
g.connect(model2_multi_dict, model2, on=('x', 'x'))
Expand Down
42 changes: 37 additions & 5 deletions test/unittest/component/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def fit(self, *args, **kwargs):
@patch.object(Mapping, '__call__')
def test_model_validate(mock_call):
# Check the metadadata recieves the correct values
model = Validator('test', object=object())
model = Validator('test', object=lambda x: x)
model._signature = 'singleton'
my_metric = MagicMock(spec=Metric)
my_metric.identifier = 'acc'
my_metric.return_value = 0.5
Expand Down Expand Up @@ -382,13 +383,12 @@ def test_sequential_model():
ObjectModel(
identifier='test-predictor-2',
object=lambda x: x + 1,
signature='singleton',
),
],
)

assert m.predict(x=1) == 4
assert m.predict_batches([((1,), {}) for _ in range(4)]) == [4, 4, 4, 4]
assert m.predict(1) == 4
assert m.predict_batches([1 for _ in range(4)]) == [4, 4, 4, 4]


def test_pm_predict_with_select_ids_multikey(monkeypatch, predict_mixin_multikey):
Expand Down Expand Up @@ -439,7 +439,8 @@ def _test(X, docs):
@pytest.fixture
def object_model():
return ObjectModel(
'test', object=lambda x: numpy.array(x) + 1, signature='singleton'
'test',
object=lambda x: numpy.array(x) + 1,
)


Expand Down Expand Up @@ -468,3 +469,34 @@ def test_object_model_as_a_listener(db, object_model):
r = results[0].unpack()
key = next(k for k in r.keys() if k.startswith('_outputs__test'))
assert all(np.allclose(r.unpack()[key], sample_data + 1) for r in results)


class MyModelSingleton(Model):
def predict(self, X):
return X


class MyModelArgs(Model):
def predict(self, X, Y):
return X


class MyModelArgsKwargs(Model):
def predict(self, X, Y=None):
return X


class MyModelKwargs(Model):
def predict(self, Y=None):
return Y


def test_signature_inference():
assert MyModelSingleton._signature == 'singleton'
assert MyModelArgs._signature == '*args'
assert MyModelArgsKwargs._signature == '*args,**kwargs'
assert MyModelKwargs._signature == '**kwargs'

m = ObjectModel('test', object=lambda x, y: x + y)

assert m.signature == '*args'
Loading

0 comments on commit ef827c4

Please sign in to comment.