From f6e56afcf3ab939218967e2658baa0937fb76f54 Mon Sep 17 00:00:00 2001 From: danielward27 Date: Wed, 25 Oct 2023 14:53:12 +0100 Subject: [PATCH] Use property --- equinox/_jit.py | 2 +- equinox/_module.py | 12 +++++++----- tests/test_module.py | 40 ++++++++++++++++++++++++++++++++-------- 3 files changed, 40 insertions(+), 14 deletions(-) diff --git a/equinox/_jit.py b/equinox/_jit.py index b2fd50bb..b347285c 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -274,7 +274,7 @@ def f(x, y): # both args traced if arrays, static if non-arrays "`jitkwargs` cannot contain 'static_argnums', 'static_argnames' or " "'donate_argnums'" ) - signature = inspect.signature(fun.__call__) + signature = inspect.signature(fun) if donate not in {"all", "warn", "none"}: raise ValueError( diff --git a/equinox/_module.py b/equinox/_module.py index 410a9ce4..6104a71f 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -317,11 +317,6 @@ def __init__(self, *args, **kwargs): # TODO: is this next line still necessary? cls.__init__.__module__ = cls.__module__ - # Assign __signature__ to match dataclass - sig = inspect.signature(cls.__init__) - params = list(sig.parameters.values())[1:] # Remove self to match dataclass sig - cls.__signature__ = sig.replace(parameters=params) - # [Step 5] We support an optional `strict` mode for Rust-like strictness in the # type checking. # In practice this is probably too much for your average user, but it's a great @@ -436,6 +431,13 @@ def __init__(self, *args, **kwargs): # Done! return cls + @property + def __signature__(cls): + # Use signature of __init__ method for non-callable equinox modules + sig = inspect.signature(cls.__init__) + params = list(sig.parameters.values())[1:] # Remove self parameter + return sig.replace(parameters=params) + # This method is called whenever you initialise a module: `MyModule(...)` def __call__(cls, *args, **kwargs): if _is_force_abstract[cls]: diff --git a/tests/test_module.py b/tests/test_module.py index c965adfc..0718bc48 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -993,22 +993,46 @@ class ExampleModel(eqx.Module): assert static_field.metadata == dict(foo=False, static=True) -def test_signature(): - class Foo(eqx.Module): +def signature_test_cases(): + @dataclasses.dataclass + class FooDataClass: + a: int + + class FooModule(eqx.Module): a: int + @dataclasses.dataclass + class CustomInitDataClass: + def __init__(self, a: int): + pass + class CustomInitModule(eqx.Module): + def __init__(self, a: int): + pass + + @dataclasses.dataclass + class CallableDataClass: a: int - def __init__(self, b: int): - self.a = b + def __call__(self, b: int): + pass - class FooCallable(eqx.Module): + class CallableModule(eqx.Module): a: int def __call__(self, b: int): pass - for T in [Foo, FooCallable, CustomInitModule, FooCallable(1)]: - print(str(T)) - print(str(inspect.signature(T))) + test_cases = [ + (FooDataClass, FooModule), + (CustomInitDataClass, CustomInitModule), + (CallableDataClass, CallableModule), + (CallableDataClass(1), CallableModule(1)), + ] + return test_cases + + +@pytest.mark.parametrize(("dataclass", "module"), signature_test_cases()) +def test_signature(dataclass, module): + # Check module signature matches dataclass signatures. + assert inspect.signature(dataclass) == inspect.signature(module)