Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate signature as argument #2657

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

#### Changed defaults / behaviours

...
- No need to add `.signature` to `Model` implementations

#### New Features & Functionality

Expand All @@ -21,8 +21,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fix the random silent failure bug when ibis creates tables.

...

## [0.5.0](https://github.com/superduper-io/superduper/compare/0.5.0...0.4.0]) (2024-Nov-02)

#### Changed defaults / behaviours
Expand All @@ -36,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Refactor secrets loading method.
- Add db.load in db wait
- Add model component cleanup
- Deprecate `signature` as parameter and auto-infer from `.predict`

#### New Features & Functionality

Expand Down
7 changes: 2 additions & 5 deletions superduper/components/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ class DocumentInput(Model):

spec: t.Union[str, t.List[str]]
identifier: str = '_input'
signature: Signature = 'singleton'

def __post_init__(self, db, example):
super().__post_init__(db, example)
Expand Down Expand Up @@ -232,12 +231,11 @@ class Graph(Model):
:param edges: Graph edges list.
:param input: Graph root node.
:param outputs: Graph output nodes.
:param signature: Graph signature.

Example:
-------
>> g = Graph(
>> identifier='simple-graph', input=model1, outputs=[model2], signature='*args'
>> identifier='simple-graph', input=model1, outputs=[model2],
>> )
>> g.connect(model1, model2)
>> assert g.predict(1) == [(4, 2)]
Expand All @@ -253,15 +251,14 @@ class Graph(Model):
)
input: Model
outputs: t.List[t.Union[str, Model]] = dc.field(default_factory=list)
signature: Signature = '*args,**kwargs'

def __post_init__(self, db, example):
self.G = nx.DiGraph()
self.nodes = {}
self.version = 0
self._db = None

self.signature = self.input.signature
self._signature = self.input.signature
if isinstance(self.outputs, list):
self.output_identifiers = [
o.identifier if isinstance(o, Model) else o for o in self.outputs
Expand Down
67 changes: 56 additions & 11 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def model(
output_schema: t.Optional[Schema] = None,
num_workers: int = 0,
example: t.Any | None = None,
signature: Signature = '*args,**kwargs',
):
"""Decorator to wrap a function with `ObjectModel`.

Expand All @@ -59,7 +58,6 @@ def model(
:param output_schema: Schema for the model outputs.
:param num_workers: Number of workers to use for parallel processing
:param example: Example to auto-determine the schema/ datatype.
:param signature: Signature for the model.
"""
if item is not None and (inspect.isclass(item) or callable(item)):
if inspect.isclass(item):
Expand All @@ -74,7 +72,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return object_model_factory
Expand All @@ -88,7 +85,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)
else:

Expand All @@ -105,7 +101,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return object_model_factory
Expand All @@ -119,7 +114,6 @@ def object_model_factory(*args, **kwargs):
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return decorated_function
Expand Down Expand Up @@ -352,13 +346,32 @@ def __new__(mcls, name, bases, dct):
if 'init' in dct:
dct['init'] = init_decorator(dct['init'])
cls = super().__new__(mcls, name, bases, dct)

signature = inspect.signature(cls.predict)
pos = []
kw = []
for k in signature.parameters:
if k == 'self':
continue
if signature.parameters[k].default == inspect._empty:
pos.append(k)
else:
kw.append(k)
if len(pos) == 1 and not kw:
cls._signature = 'singleton'
elif pos and not kw:
cls._signature = '*args'
elif pos and kw:
cls._signature = '*args,**kwargs'
else:
assert not pos and kw
cls._signature = '**kwargs'
return cls


class Model(Component, metaclass=ModelMeta):
"""Base class for components which can predict.

:param signature: Model signature.
:param datatype: DataType instance.
:param output_schema: Output schema (mapping of encoders).
:param model_update_kwargs: The kwargs to use for model update.
Expand All @@ -378,7 +391,6 @@ class Model(Component, metaclass=ModelMeta):

breaks: t.ClassVar[t.Sequence] = ('trainer',)
type_id: t.ClassVar[str] = 'model'
signature: Signature = '*args,**kwargs'
datatype: EncoderArg = None
output_schema: t.Optional[Schema] = None
model_update_kwargs: t.Dict = dc.field(default_factory=dict)
Expand Down Expand Up @@ -408,6 +420,10 @@ def cleanup(self, db: "Datalayer") -> None:
super().cleanup(db=db)
db.cluster.compute.drop(self)

@property
def signature(self):
return self._signature

@property
def inputs(self) -> Inputs:
"""Instance of `Inputs` to represent model params."""
Expand Down Expand Up @@ -1047,6 +1063,10 @@ class ImportedModel(Model):
object: Leaf
method: t.Optional[str] = None

def __post_init__(self, db, example):
super().__post_init__(db, example)
self._inferred_signature = None

@staticmethod
def _infer_signature(object):
# find positional and key-word parameters from the object
Expand All @@ -1067,6 +1087,13 @@ def _infer_signature(object):
return '**kwargs'
return '*args,**kwargs'

@property
@ensure_initialized
def signature(self):
if self._inferred_signature is None:
self._inferred_signature = self._infer_signature(self.object)
return self._inferred_signature

@property
def outputs(self):
"""Get an instance of ``IndexableNode`` to index outputs."""
Expand Down Expand Up @@ -1206,6 +1233,7 @@ class QueryModel(Model):
:param preprocess: Preprocess callable
:param postprocess: Postprocess callable
:param select: query used to find data (can include `like`)
:param signature: signature to use
"""

preprocess: t.Optional[t.Callable] = None
Expand Down Expand Up @@ -1267,10 +1295,13 @@ class SequentialModel(Model):
models: t.List[Model]

def __post_init__(self, db, example):
self.signature = self.models[0].signature
self.datatype = self.models[-1].datatype
return super().__post_init__(db, example)

@property
def signature(self):
return self.models[0].signature

@property
def inputs(self) -> Inputs:
"""Instance of `Inputs` to represent model params."""
Expand Down Expand Up @@ -1300,7 +1331,22 @@ def predict(self, *args, **kwargs):
:param args: Positional arguments to predict on.
:param kwargs: Keyword arguments to predict on.
"""
return self.predict_batches([(args, kwargs)])[0]
for i, p in enumerate(self.models):
assert isinstance(p, Model), f'Expected `Model`, got {type(p)}'
if i == 0:
out = p.predict(*args, **kwargs)
else:
if p.signature == 'singleton':
out = p.predict(out)
elif p.signature == '*args':
out = p.predict(*out)
elif p.signature == '**kwargs':
out = p.predict(**out)
else:
msg = 'Model defines a predict with no free parameters'
assert p.signature == '*args,**kwargs', msg
out = p.predict(*out[0], **out[1])
return out

def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
"""Execute on series of data point defined in dataset.
Expand Down Expand Up @@ -1367,7 +1413,6 @@ class RAGModel(Model):
breaks: t.ClassVar[t.Sequence] = ('llm', 'prompt_template')

prompt_template: str
signature: str = 'singleton'
select: Query
key: str
llm: Model
Expand Down
5 changes: 2 additions & 3 deletions superduper/rest/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,9 @@ def db_remove(
@app.add('/db/show_template', method='get')
def db_show_template(
identifier: str,
type_id: str = 'template',
db: 'Datalayer' = DatalayerDependency(),
):
template: Template = db.load(type_id=type_id, identifier=identifier)
template: Template = db.load(type_id='template', identifier=identifier)
return template.form_template

@app.add('/db/edit', method='get')
Expand All @@ -293,7 +292,7 @@ def db_edit(
db: 'Datalayer' = DatalayerDependency(),
):
component = db.load(type_id, identifier)
template = db.load('template', component.build_template)
template: Template = db.load('template', component.build_template)
form = template.form_template
form['_variables'] = component.build_variables
return form
Expand Down
7 changes: 4 additions & 3 deletions test/unittest/base/test_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ def test_replace(db: Datalayer):
object=lambda x: x + 1,
identifier='m',
datatype=DataType(identifier='base'),
signature='singleton',
)
model.version = 0
db.apply(model)
Expand All @@ -484,7 +483,8 @@ def test_replace(db: Datalayer):
assert db.load('model', 'm').predict(1) == 2

new_model = ObjectModel(
object=lambda x: x + 2, identifier='m', signature='singleton'
object=lambda x: x + 2,
identifier='m',
)
new_model.version = 0
db.replace(new_model)
Expand All @@ -496,7 +496,8 @@ def test_replace(db: Datalayer):

# replace the last version of the model
new_model = ObjectModel(
object=lambda x: x + 3, identifier='m', signature='singleton'
object=lambda x: x + 3,
identifier='m',
)
new_model.version = 0
db.replace(new_model)
Expand Down
11 changes: 7 additions & 4 deletions test/unittest/component/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def model1(db):
def model_object(x):
return x + 1

model = ObjectModel(identifier='m1', object=model_object, signature='singleton')
model = ObjectModel(identifier='m1', object=model_object)
yield model


Expand Down Expand Up @@ -53,13 +53,17 @@ def model_object(x, y):

def test_simple_graph(model1, model2):
g = Graph(
identifier='simple-graph', input=model1, outputs=model2, signature='*args'
identifier='simple-graph',
input=model1,
outputs=model2,
)
g.connect(model1, model2)
assert g.predict(1) == (4, 2)

g = Graph(
identifier='simple-graph', input=model1, outputs=model2, signature='*args'
identifier='simple-graph',
input=model1,
outputs=model2,
)
g.connect(model1, model2)
assert g.predict_batches([1, 2, 3]) == [(4, 2), (5, 3), (6, 4)]
Expand All @@ -70,7 +74,6 @@ def test_graph_output_indexing(model2_multi_dict, model2, model1):
identifier='simple-graph',
input=model1,
outputs=[model2],
signature='**kwargs',
)
g.connect(model1, model2_multi_dict, on=(None, 'x'))
g.connect(model2_multi_dict, model2, on=('x', 'x'))
Expand Down
Loading
Loading