From 39ac206e462ecdaa8358c233a5a96193f078856d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 2 Jan 2025 17:54:40 +0800 Subject: [PATCH] feat(pre-commit): add `mypy` to pre-commit hook --- .pre-commit-config.yaml | 11 ++++++++++- optree/typing.py | 11 +++++++---- tests/test_accessor.py | 1 - tests/test_dataclasses.py | 1 - tests/test_typing.py | 1 - 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 22439d6e..3b98ce3f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.1 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -63,6 +63,15 @@ repos: hooks: - id: codespell additional_dependencies: [".[toml]"] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.14.1 + hooks: + - id: mypy + exclude: | + (?x)( + ^tests/| + ^setup.py$ + ) - repo: local hooks: - id: pylint diff --git a/optree/typing.py b/optree/typing.py index bec28fb0..d2463aae 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -361,7 +361,10 @@ def __call__(self, metadata: MetaData, children: Children[T], /) -> Collection[T """Unflatten the children and metadata back into the container.""" -def _override_with_(cxx_implementation: F, /) -> Callable[[F], F]: +def _override_with_( + cxx_implementation: Callable[P, T], + /, +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Decorator to override the Python implementation with the C++ implementation. >>> @_override_with_(any) @@ -375,15 +378,15 @@ def _override_with_(cxx_implementation: F, /) -> Callable[[F], F]: True """ - def wrapper(python_implementation: F, /) -> F: + def wrapper(python_implementation: Callable[P, T], /) -> Callable[P, T]: @functools.wraps(python_implementation) - def wrapped(*args: Any, **kwargs: Any) -> Any: + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: return cxx_implementation(*args, **kwargs) wrapped.__cxx_implementation__ = cxx_implementation # type: ignore[attr-defined] wrapped.__python_implementation__ = python_implementation # type: ignore[attr-defined] - return wrapped # type: ignore[return-value] + return wrapped return wrapper diff --git a/tests/test_accessor.py b/tests/test_accessor.py index ba6f7627..d2e95575 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -416,7 +416,6 @@ class SubclassedAutoEntry(optree.AutoEntry): def test_flattened_entry_call(): - @optree.register_pytree_node_class(namespace='namespace') class MyObject: def __init__(self, x, y, z): diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index ab61a473..88bf33df 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -628,7 +628,6 @@ def test_make_dataclass_with_duplicate_registrations(): TypeError, match=r'@optree\.dataclasses\.dataclass\(\) cannot be applied to .* more than once\.', ): - optree.dataclasses.dataclass(Foo2, namespace='other-error') Foo = optree.register_pytree_node_class(namespace='other-namespace')( # noqa: N806 diff --git a/tests/test_typing.py b/tests/test_typing.py index c3d8998b..21d61b42 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -88,7 +88,6 @@ def is_namedtuple_(obj): optree.is_namedtuple_instance.__python_implementation__, ), ): - assert not is_namedtuple_((1, 2)) assert not is_namedtuple_([1, 2]) assert not is_namedtuple_(sys.float_info)