From 21d4b4b417941504562c97a77663b0bfe56fb546 Mon Sep 17 00:00:00 2001 From: Duncan Blythe Date: Fri, 29 Nov 2024 15:05:16 +0100 Subject: [PATCH] Deprecate signature as argument --- CHANGELOG.md | 5 +- superduper/components/graph.py | 7 +-- superduper/components/model.py | 67 +++++++++++++++++++++----- superduper/rest/build.py | 5 +- test/unittest/base/test_datalayer.py | 7 +-- test/unittest/component/test_graph.py | 11 +++-- test/unittest/component/test_model.py | 42 ++++++++++++++-- test/unittest/component/test_plugin.py | 10 ++-- test/unittest/ext/test_vanilla.py | 2 +- 9 files changed, 116 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a829a0884..5591ab61a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 #### Changed defaults / behaviours -... +- No need to add `.signature` to `Model` implementations #### New Features & Functionality @@ -21,8 +21,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix the random silent failure bug when ibis creates tables. -... - ## [0.5.0](https://github.com/superduper-io/superduper/compare/0.5.0...0.4.0]) (2024-Nov-02) #### Changed defaults / behaviours @@ -36,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Refactor secrets loading method. - Add db.load in db wait - Add model component cleanup +- Deprecate `signature` as parameter and auto-infer from `.predict` #### New Features & Functionality diff --git a/superduper/components/graph.py b/superduper/components/graph.py index 94555033b..9446f552b 100644 --- a/superduper/components/graph.py +++ b/superduper/components/graph.py @@ -197,7 +197,6 @@ class DocumentInput(Model): spec: t.Union[str, t.List[str]] identifier: str = '_input' - signature: Signature = 'singleton' def __post_init__(self, db, example): super().__post_init__(db, example) @@ -232,12 +231,11 @@ class Graph(Model): :param edges: Graph edges list. :param input: Graph root node. :param outputs: Graph output nodes. - :param signature: Graph signature. Example: ------- >> g = Graph( - >> identifier='simple-graph', input=model1, outputs=[model2], signature='*args' + >> identifier='simple-graph', input=model1, outputs=[model2], >> ) >> g.connect(model1, model2) >> assert g.predict(1) == [(4, 2)] @@ -253,7 +251,6 @@ class Graph(Model): ) input: Model outputs: t.List[t.Union[str, Model]] = dc.field(default_factory=list) - signature: Signature = '*args,**kwargs' def __post_init__(self, db, example): self.G = nx.DiGraph() @@ -261,7 +258,7 @@ def __post_init__(self, db, example): self.version = 0 self._db = None - self.signature = self.input.signature + self._signature = self.input.signature if isinstance(self.outputs, list): self.output_identifiers = [ o.identifier if isinstance(o, Model) else o for o in self.outputs diff --git a/superduper/components/model.py b/superduper/components/model.py index 3c08d6070..e69ce85d7 100644 --- a/superduper/components/model.py +++ b/superduper/components/model.py @@ -45,7 +45,6 @@ def model( output_schema: t.Optional[Schema] = None, num_workers: int = 0, example: t.Any | None = None, - signature: Signature = '*args,**kwargs', ): """Decorator to wrap a function with `ObjectModel`. @@ -59,7 +58,6 @@ def model( :param output_schema: Schema for the model outputs. :param num_workers: Number of workers to use for parallel processing :param example: Example to auto-determine the schema/ datatype. - :param signature: Signature for the model. """ if item is not None and (inspect.isclass(item) or callable(item)): if inspect.isclass(item): @@ -74,7 +72,6 @@ def object_model_factory(*args, **kwargs): output_schema=output_schema, num_workers=num_workers, example=example, - signature=signature, ) return object_model_factory @@ -88,7 +85,6 @@ def object_model_factory(*args, **kwargs): output_schema=output_schema, num_workers=num_workers, example=example, - signature=signature, ) else: @@ -105,7 +101,6 @@ def object_model_factory(*args, **kwargs): output_schema=output_schema, num_workers=num_workers, example=example, - signature=signature, ) return object_model_factory @@ -119,7 +114,6 @@ def object_model_factory(*args, **kwargs): output_schema=output_schema, num_workers=num_workers, example=example, - signature=signature, ) return decorated_function @@ -352,13 +346,32 @@ def __new__(mcls, name, bases, dct): if 'init' in dct: dct['init'] = init_decorator(dct['init']) cls = super().__new__(mcls, name, bases, dct) + + signature = inspect.signature(cls.predict) + pos = [] + kw = [] + for k in signature.parameters: + if k == 'self': + continue + if signature.parameters[k].default == inspect._empty: + pos.append(k) + else: + kw.append(k) + if len(pos) == 1 and not kw: + cls._signature = 'singleton' + elif pos and not kw: + cls._signature = '*args' + elif pos and kw: + cls._signature = '*args,**kwargs' + else: + assert not pos and kw + cls._signature = '**kwargs' return cls class Model(Component, metaclass=ModelMeta): """Base class for components which can predict. - :param signature: Model signature. :param datatype: DataType instance. :param output_schema: Output schema (mapping of encoders). :param model_update_kwargs: The kwargs to use for model update. @@ -378,7 +391,6 @@ class Model(Component, metaclass=ModelMeta): breaks: t.ClassVar[t.Sequence] = ('trainer',) type_id: t.ClassVar[str] = 'model' - signature: Signature = '*args,**kwargs' datatype: EncoderArg = None output_schema: t.Optional[Schema] = None model_update_kwargs: t.Dict = dc.field(default_factory=dict) @@ -408,6 +420,10 @@ def cleanup(self, db: "Datalayer") -> None: super().cleanup(db=db) db.cluster.compute.drop(self) + @property + def signature(self): + return self._signature + @property def inputs(self) -> Inputs: """Instance of `Inputs` to represent model params.""" @@ -1047,6 +1063,10 @@ class ImportedModel(Model): object: Leaf method: t.Optional[str] = None + def __post_init__(self, db, example): + super().__post_init__(db, example) + self._inferred_signature = None + @staticmethod def _infer_signature(object): # find positional and key-word parameters from the object @@ -1067,6 +1087,13 @@ def _infer_signature(object): return '**kwargs' return '*args,**kwargs' + @property + @ensure_initialized + def signature(self): + if self._inferred_signature is None: + self._inferred_signature = self._infer_signature(self.object) + return self._inferred_signature + @property def outputs(self): """Get an instance of ``IndexableNode`` to index outputs.""" @@ -1206,6 +1233,7 @@ class QueryModel(Model): :param preprocess: Preprocess callable :param postprocess: Postprocess callable :param select: query used to find data (can include `like`) + :param signature: signature to use """ preprocess: t.Optional[t.Callable] = None @@ -1267,10 +1295,13 @@ class SequentialModel(Model): models: t.List[Model] def __post_init__(self, db, example): - self.signature = self.models[0].signature self.datatype = self.models[-1].datatype return super().__post_init__(db, example) + @property + def signature(self): + return self.models[0].signature + @property def inputs(self) -> Inputs: """Instance of `Inputs` to represent model params.""" @@ -1300,7 +1331,22 @@ def predict(self, *args, **kwargs): :param args: Positional arguments to predict on. :param kwargs: Keyword arguments to predict on. """ - return self.predict_batches([(args, kwargs)])[0] + for i, p in enumerate(self.models): + assert isinstance(p, Model), f'Expected `Model`, got {type(p)}' + if i == 0: + out = p.predict(*args, **kwargs) + else: + if p.signature == 'singleton': + out = p.predict(out) + elif p.signature == '*args': + out = p.predict(*out) + elif p.signature == '**kwargs': + out = p.predict(**out) + else: + msg = 'Model defines a predict with no free parameters' + assert p.signature == '*args,**kwargs', msg + out = p.predict(*out[0], **out[1]) + return out def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: """Execute on series of data point defined in dataset. @@ -1367,7 +1413,6 @@ class RAGModel(Model): breaks: t.ClassVar[t.Sequence] = ('llm', 'prompt_template') prompt_template: str - signature: str = 'singleton' select: Query key: str llm: Model diff --git a/superduper/rest/build.py b/superduper/rest/build.py index fea407f9d..acdcabc88 100644 --- a/superduper/rest/build.py +++ b/superduper/rest/build.py @@ -280,10 +280,9 @@ def db_remove( @app.add('/db/show_template', method='get') def db_show_template( identifier: str, - type_id: str = 'template', db: 'Datalayer' = DatalayerDependency(), ): - template: Template = db.load(type_id=type_id, identifier=identifier) + template: Template = db.load(type_id='template', identifier=identifier) return template.form_template @app.add('/db/edit', method='get') @@ -293,7 +292,7 @@ def db_edit( db: 'Datalayer' = DatalayerDependency(), ): component = db.load(type_id, identifier) - template = db.load('template', component.build_template) + template: Template = db.load('template', component.build_template) form = template.form_template form['_variables'] = component.build_variables return form diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py index bdb5a8b95..3cc82e7b0 100644 --- a/test/unittest/base/test_datalayer.py +++ b/test/unittest/base/test_datalayer.py @@ -475,7 +475,6 @@ def test_replace(db: Datalayer): object=lambda x: x + 1, identifier='m', datatype=DataType(identifier='base'), - signature='singleton', ) model.version = 0 db.apply(model) @@ -484,7 +483,8 @@ def test_replace(db: Datalayer): assert db.load('model', 'm').predict(1) == 2 new_model = ObjectModel( - object=lambda x: x + 2, identifier='m', signature='singleton' + object=lambda x: x + 2, + identifier='m', ) new_model.version = 0 db.replace(new_model) @@ -496,7 +496,8 @@ def test_replace(db: Datalayer): # replace the last version of the model new_model = ObjectModel( - object=lambda x: x + 3, identifier='m', signature='singleton' + object=lambda x: x + 3, + identifier='m', ) new_model.version = 0 db.replace(new_model) diff --git a/test/unittest/component/test_graph.py b/test/unittest/component/test_graph.py index eccca41e9..d651bd25b 100644 --- a/test/unittest/component/test_graph.py +++ b/test/unittest/component/test_graph.py @@ -11,7 +11,7 @@ def model1(db): def model_object(x): return x + 1 - model = ObjectModel(identifier='m1', object=model_object, signature='singleton') + model = ObjectModel(identifier='m1', object=model_object) yield model @@ -53,13 +53,17 @@ def model_object(x, y): def test_simple_graph(model1, model2): g = Graph( - identifier='simple-graph', input=model1, outputs=model2, signature='*args' + identifier='simple-graph', + input=model1, + outputs=model2, ) g.connect(model1, model2) assert g.predict(1) == (4, 2) g = Graph( - identifier='simple-graph', input=model1, outputs=model2, signature='*args' + identifier='simple-graph', + input=model1, + outputs=model2, ) g.connect(model1, model2) assert g.predict_batches([1, 2, 3]) == [(4, 2), (5, 3), (6, 4)] @@ -70,7 +74,6 @@ def test_graph_output_indexing(model2_multi_dict, model2, model1): identifier='simple-graph', input=model1, outputs=[model2], - signature='**kwargs', ) g.connect(model1, model2_multi_dict, on=(None, 'x')) g.connect(model2_multi_dict, model2, on=('x', 'x')) diff --git a/test/unittest/component/test_model.py b/test/unittest/component/test_model.py index c674d8986..b9b89882a 100644 --- a/test/unittest/component/test_model.py +++ b/test/unittest/component/test_model.py @@ -220,7 +220,8 @@ def fit(self, *args, **kwargs): @patch.object(Mapping, '__call__') def test_model_validate(mock_call): # Check the metadadata recieves the correct values - model = Validator('test', object=object()) + model = Validator('test', object=lambda x: x) + model._signature = 'singleton' my_metric = MagicMock(spec=Metric) my_metric.identifier = 'acc' my_metric.return_value = 0.5 @@ -382,13 +383,12 @@ def test_sequential_model(): ObjectModel( identifier='test-predictor-2', object=lambda x: x + 1, - signature='singleton', ), ], ) - assert m.predict(x=1) == 4 - assert m.predict_batches([((1,), {}) for _ in range(4)]) == [4, 4, 4, 4] + assert m.predict(1) == 4 + assert m.predict_batches([1 for _ in range(4)]) == [4, 4, 4, 4] def test_pm_predict_with_select_ids_multikey(monkeypatch, predict_mixin_multikey): @@ -439,7 +439,8 @@ def _test(X, docs): @pytest.fixture def object_model(): return ObjectModel( - 'test', object=lambda x: numpy.array(x) + 1, signature='singleton' + 'test', + object=lambda x: numpy.array(x) + 1, ) @@ -468,3 +469,34 @@ def test_object_model_as_a_listener(db, object_model): r = results[0].unpack() key = next(k for k in r.keys() if k.startswith('_outputs__test')) assert all(np.allclose(r.unpack()[key], sample_data + 1) for r in results) + + +class MyModelSingleton(Model): + def predict(self, X): + return X + + +class MyModelArgs(Model): + def predict(self, X, Y): + return X + + +class MyModelArgsKwargs(Model): + def predict(self, X, Y=None): + return X + + +class MyModelKwargs(Model): + def predict(self, Y=None): + return Y + + +def test_signature_inference(): + assert MyModelSingleton._signature == 'singleton' + assert MyModelArgs._signature == '*args' + assert MyModelArgsKwargs._signature == '*args,**kwargs' + assert MyModelKwargs._signature == '**kwargs' + + m = ObjectModel('test', object=lambda x, y: x + y) + + assert m.signature == '*args' diff --git a/test/unittest/component/test_plugin.py b/test/unittest/component/test_plugin.py index 8460a5da9..f4b0262da 100644 --- a/test/unittest/component/test_plugin.py +++ b/test/unittest/component/test_plugin.py @@ -10,7 +10,7 @@ from superduper import Model class PModel(Model): - def predict(self) -> int: + def predict(self, X) -> int: return "{plugin_type}" """ @@ -82,7 +82,7 @@ def test_module(tmpdir): from p_module import PModel model = PModel("test") - assert model.predict() == "module" + assert model.predict(2) == "module" def test_package(tmpdir): @@ -96,7 +96,7 @@ def test_package(tmpdir): from p_package.p_package import PModel model = PModel("test") - assert model.predict() == "package" + assert model.predict(2) == "package" def test_directory(tmpdir): @@ -111,7 +111,7 @@ def test_directory(tmpdir): model = PModel("test") - assert model.predict() == "directory" + assert model.predict(2) == "directory" def test_repeated_loading(tmpdir): @@ -151,7 +151,7 @@ def test_import(tmpdir): from p_import.p_import import PModel model = PModel("test") - assert model.predict() == "import" + assert model.predict(2) == "import" def test_apply(db, tmpdir): diff --git a/test/unittest/ext/test_vanilla.py b/test/unittest/ext/test_vanilla.py index ac529d301..899b07032 100644 --- a/test/unittest/ext/test_vanilla.py +++ b/test/unittest/ext/test_vanilla.py @@ -21,7 +21,7 @@ def test_function_predict(): def test_function_predict_batches(): - function = ObjectModel(object=lambda x: x, identifier='test', signature='singleton') + function = ObjectModel(object=lambda x: x, identifier='test') assert function.predict_batches([1, 1]) == [1, 1]