Skip to content

Commit

Permalink
Module signatures (#573)
Browse files Browse the repository at this point in the history
* Fix contributing.md typo

* Assign __signature__ in modules

* Revert "Fix contributing.md typo"

This reverts commit 465313a.

* Print stuff

* Use property
  • Loading branch information
danielward27 authored Oct 25, 2023
1 parent 459941d commit 05d07bf
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
7 changes: 7 additions & 0 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,13 @@ def __init__(self, *args, **kwargs):
# Done!
return cls

@property
def __signature__(cls):
# Use signature of __init__ method for non-callable equinox modules
sig = inspect.signature(cls.__init__)
params = list(sig.parameters.values())[1:] # Remove self parameter
return sig.replace(parameters=params)

# This method is called whenever you initialise a module: `MyModule(...)`
def __call__(cls, *args, **kwargs):
if _is_force_abstract[cls]:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import dataclasses
import functools as ft
import inspect
from collections.abc import Callable
from dataclasses import InitVar
from typing import Any, Optional
Expand Down Expand Up @@ -990,3 +991,48 @@ class ExampleModel(eqx.Module):
dynamic_field, static_field = dataclasses.fields(model)
assert dynamic_field.metadata == dict(foo=True)
assert static_field.metadata == dict(foo=False, static=True)


def signature_test_cases():
@dataclasses.dataclass
class FooDataClass:
a: int

class FooModule(eqx.Module):
a: int

@dataclasses.dataclass
class CustomInitDataClass:
def __init__(self, a: int):
pass

class CustomInitModule(eqx.Module):
def __init__(self, a: int):
pass

@dataclasses.dataclass
class CallableDataClass:
a: int

def __call__(self, b: int):
pass

class CallableModule(eqx.Module):
a: int

def __call__(self, b: int):
pass

test_cases = [
(FooDataClass, FooModule),
(CustomInitDataClass, CustomInitModule),
(CallableDataClass, CallableModule),
(CallableDataClass(1), CallableModule(1)),
]
return test_cases


@pytest.mark.parametrize(("dataclass", "module"), signature_test_cases())
def test_signature(dataclass, module):
# Check module signature matches dataclass signatures.
assert inspect.signature(dataclass) == inspect.signature(module)

0 comments on commit 05d07bf

Please sign in to comment.