Skip to content

Commit

Permalink
Use property
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed Oct 25, 2023
1 parent be554d5 commit f6e56af
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 14 deletions.
2 changes: 1 addition & 1 deletion equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def f(x, y): # both args traced if arrays, static if non-arrays
"`jitkwargs` cannot contain 'static_argnums', 'static_argnames' or "
"'donate_argnums'"
)
signature = inspect.signature(fun.__call__)
signature = inspect.signature(fun)

if donate not in {"all", "warn", "none"}:
raise ValueError(
Expand Down
12 changes: 7 additions & 5 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,6 @@ def __init__(self, *args, **kwargs):
# TODO: is this next line still necessary?
cls.__init__.__module__ = cls.__module__

# Assign __signature__ to match dataclass
sig = inspect.signature(cls.__init__)
params = list(sig.parameters.values())[1:] # Remove self to match dataclass sig
cls.__signature__ = sig.replace(parameters=params)

# [Step 5] We support an optional `strict` mode for Rust-like strictness in the
# type checking.
# In practice this is probably too much for your average user, but it's a great
Expand Down Expand Up @@ -436,6 +431,13 @@ def __init__(self, *args, **kwargs):
# Done!
return cls

@property
def __signature__(cls):
# Use signature of __init__ method for non-callable equinox modules
sig = inspect.signature(cls.__init__)
params = list(sig.parameters.values())[1:] # Remove self parameter
return sig.replace(parameters=params)

# This method is called whenever you initialise a module: `MyModule(...)`
def __call__(cls, *args, **kwargs):
if _is_force_abstract[cls]:
Expand Down
40 changes: 32 additions & 8 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,22 +993,46 @@ class ExampleModel(eqx.Module):
assert static_field.metadata == dict(foo=False, static=True)


def test_signature():
class Foo(eqx.Module):
def signature_test_cases():
@dataclasses.dataclass
class FooDataClass:
a: int

class FooModule(eqx.Module):
a: int

@dataclasses.dataclass
class CustomInitDataClass:
def __init__(self, a: int):
pass

class CustomInitModule(eqx.Module):
def __init__(self, a: int):
pass

@dataclasses.dataclass
class CallableDataClass:
a: int

def __init__(self, b: int):
self.a = b
def __call__(self, b: int):
pass

class FooCallable(eqx.Module):
class CallableModule(eqx.Module):
a: int

def __call__(self, b: int):
pass

for T in [Foo, FooCallable, CustomInitModule, FooCallable(1)]:
print(str(T))
print(str(inspect.signature(T)))
test_cases = [
(FooDataClass, FooModule),
(CustomInitDataClass, CustomInitModule),
(CallableDataClass, CallableModule),
(CallableDataClass(1), CallableModule(1)),
]
return test_cases


@pytest.mark.parametrize(("dataclass", "module"), signature_test_cases())
def test_signature(dataclass, module):
# Check module signature matches dataclass signatures.
assert inspect.signature(dataclass) == inspect.signature(module)

0 comments on commit f6e56af

Please sign in to comment.