From 617e0434d417f53756c17e52bf31e0aedb598ded Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 7 Sep 2023 04:25:18 -0700 Subject: [PATCH] Added equinox.internal.closure_to_pytree --- .pre-commit-config.yaml | 2 +- equinox/internal/__init__.py | 1 + equinox/internal/_closure_to_pytree.py | 125 +++++++++++++++++++++++++ tests/requirements.txt | 1 + tests/test_closure_to_pytree.py | 45 +++++++++ 5 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 equinox/internal/_closure_to_pytree.py create mode 100644 tests/test_closure_to_pytree.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index abd30f2d..52f1f122 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: rev: v1.1.315 hooks: - id: pyright - additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions] + additional_dependencies: [beartype, einops, jax, jaxtyping, optax, pytest, tensorflow, tf2onnx, typing_extensions] - repo: https://github.com/nbQA-dev/nbQA rev: 1.6.3 hooks: diff --git a/equinox/internal/__init__.py b/equinox/internal/__init__.py index 48a29731..84b86f7f 100644 --- a/equinox/internal/__init__.py +++ b/equinox/internal/__init__.py @@ -41,6 +41,7 @@ store_dce as store_dce, ) from ..debug._announce_transform import announce_jaxpr_p as announce_jaxpr_p +from ._closure_to_pytree import closure_to_pytree as closure_to_pytree from ._finalise_jaxpr import ( finalise_eval_jaxpr as finalise_eval_jaxpr, finalise_fn as finalise_fn, diff --git a/equinox/internal/_closure_to_pytree.py b/equinox/internal/_closure_to_pytree.py new file mode 100644 index 00000000..318991af --- /dev/null +++ b/equinox/internal/_closure_to_pytree.py @@ -0,0 +1,125 @@ +# This is some mildly unpleasant code. +# +# Basically, Optax make the decision *not* to register their optimisers as PyTrees. +# This means that we often end up with spurious recompilation, just because a learning +# rate changed. That results in a new optimiser instance, which is just a function and +# is treated statically. +# +# So here we simply replace all function closures with pytrees, with each of their cell +# contents as their subnodes. + +import types +from typing import Any, Optional + +import jax.tree_util as jtu + +from .._module import Module + + +def _make_cell(val): + fn = lambda: val + return fn.__closure__[0] # pyright: ignore + + +def _adjust_function_closure(fn, closure): + out = types.FunctionType( + code=fn.__code__, + globals=fn.__globals__, + name=fn.__name__, + argdefs=fn.__defaults__, + closure=closure, + ) + out.__module__ = fn.__module__ + out.__qualname__ = fn.__qualname__ + out.__doc__ = fn.__doc__ + out.__annotations__.update(fn.__annotations__) + if fn.__kwdefaults__ is not None: + out.__kwdefaults__ = fn.__kwdefaults__.copy() + return out + + +# Not a pytree. +# Used so that two different local functions, with different identities, can still +# compare equal. This is needed as these leaves are compared statically when +# filter-jit'ing. +class _FunctionWithEquality: + def __init__(self, fn: types.FunctionType): + self.fn = fn + + def information(self): + return self.fn.__qualname__, self.fn.__module__ + + def __hash__(self): + return hash(self.information()) + + def __eq__(self, other): + return type(self) == type(other) and self.information() == other.information() + + +class _Closure(Module): + fn: _FunctionWithEquality + contents: Optional[tuple[Any, ...]] + + def __init__(self, fn: types.FunctionType): + self.fn = _FunctionWithEquality(fn) + if fn.__closure__ is None: + contents = None + else: + contents = tuple( + closure_to_pytree(cell.cell_contents) for cell in fn.__closure__ + ) + self.contents = contents + + def __call__(self, *args, **kwargs): + if self.contents is None: + closure = None + else: + closure = tuple(_make_cell(contents) for contents in self.contents) + fn = _adjust_function_closure(self.fn.fn, closure) + return fn(*args, **kwargs) + + +def _fixup_closure(leaf): + if isinstance(leaf, types.FunctionType): + return _Closure(leaf) + else: + return leaf + + +def closure_to_pytree(tree): + """Convert all function closures into pytree nodes. + + **Arguments:** + + - `tree`: Any pytree. + + **Returns:** + + A copy of `tree`, where all function closures have been replaced by a new object + that is (a) callable like the original function, but (b) iterates over its + `__closure__` as subnodes in the pytree. + + !!! Example + + ```python + def some_fn(): + a = jnp.array(1.) + + @closure_to_pytree + def f(x): + return x + a + + print(jax.tree_util.tree_leaves(f)) # prints out `a` + ``` + + !!! Warning + + One annoying technical detail in the above example: we had to wrap the whole lot + in a `some_fn`, so that we're in a local scope. Python treats functions at the + global scope differently, and this conversion won't result in any global + variable being treated as part of the pytree. + + In practice, the intended use case of this function is to fix Optax, which + always uses local functions. + """ + return jtu.tree_map(_fixup_closure, tree) diff --git a/tests/requirements.txt b/tests/requirements.txt index 6147e392..946ff033 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,4 @@ jaxlib +optax pytest beartype diff --git a/tests/test_closure_to_pytree.py b/tests/test_closure_to_pytree.py new file mode 100644 index 00000000..c7b8202c --- /dev/null +++ b/tests/test_closure_to_pytree.py @@ -0,0 +1,45 @@ +import jax.numpy as jnp +import jax.tree_util as jtu +import optax + +import equinox as eqx +import equinox.internal as eqxi + + +def test_fixup_optax(): + lr = jnp.array(1e-3) + optim = optax.chain( + optax.adam(lr), + optax.scale_by_schedule(optax.piecewise_constant_schedule(1, {200: 0.1})), + ) + optim = eqxi.closure_to_pytree(optim) + + for leaf in jtu.tree_leaves(optim): + if eqx.is_array(leaf) and leaf == -lr: + break + else: + assert False + + # Check that we can still init and update as normal. + grads = params = {"foo": jnp.array(1.0)} + state = optim.init(params) + optim.update(grads, state) + + lr = jnp.array(1e-2) + optim2 = optax.chain( + optax.adam(lr), + optax.scale_by_schedule(optax.piecewise_constant_schedule(1, {200: 0.1})), + ) + optim2 = eqxi.closure_to_pytree(optim2) + + compiling = 0 + + @eqx.filter_jit + def f(x): + nonlocal compiling + compiling += 1 + + f(optim) + assert compiling == 1 + f(optim2) + assert compiling == 1