Skip to content

Commit

Permalink
Added equinox.internal.closure_to_pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 7, 2023
1 parent e3b79c4 commit 617e043
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
rev: v1.1.315
hooks:
- id: pyright
additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions]
additional_dependencies: [beartype, einops, jax, jaxtyping, optax, pytest, tensorflow, tf2onnx, typing_extensions]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.3
hooks:
Expand Down
1 change: 1 addition & 0 deletions equinox/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
store_dce as store_dce,
)
from ..debug._announce_transform import announce_jaxpr_p as announce_jaxpr_p
from ._closure_to_pytree import closure_to_pytree as closure_to_pytree
from ._finalise_jaxpr import (
finalise_eval_jaxpr as finalise_eval_jaxpr,
finalise_fn as finalise_fn,
Expand Down
125 changes: 125 additions & 0 deletions equinox/internal/_closure_to_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# This is some mildly unpleasant code.
#
# Basically, Optax make the decision *not* to register their optimisers as PyTrees.
# This means that we often end up with spurious recompilation, just because a learning
# rate changed. That results in a new optimiser instance, which is just a function and
# is treated statically.
#
# So here we simply replace all function closures with pytrees, with each of their cell
# contents as their subnodes.

import types
from typing import Any, Optional

import jax.tree_util as jtu

from .._module import Module


def _make_cell(val):
fn = lambda: val
return fn.__closure__[0] # pyright: ignore


def _adjust_function_closure(fn, closure):
out = types.FunctionType(
code=fn.__code__,
globals=fn.__globals__,
name=fn.__name__,
argdefs=fn.__defaults__,
closure=closure,
)
out.__module__ = fn.__module__
out.__qualname__ = fn.__qualname__
out.__doc__ = fn.__doc__
out.__annotations__.update(fn.__annotations__)
if fn.__kwdefaults__ is not None:
out.__kwdefaults__ = fn.__kwdefaults__.copy()
return out


# Not a pytree.
# Used so that two different local functions, with different identities, can still
# compare equal. This is needed as these leaves are compared statically when
# filter-jit'ing.
class _FunctionWithEquality:
def __init__(self, fn: types.FunctionType):
self.fn = fn

def information(self):
return self.fn.__qualname__, self.fn.__module__

def __hash__(self):
return hash(self.information())

def __eq__(self, other):
return type(self) == type(other) and self.information() == other.information()


class _Closure(Module):
fn: _FunctionWithEquality
contents: Optional[tuple[Any, ...]]

def __init__(self, fn: types.FunctionType):
self.fn = _FunctionWithEquality(fn)
if fn.__closure__ is None:
contents = None
else:
contents = tuple(
closure_to_pytree(cell.cell_contents) for cell in fn.__closure__
)
self.contents = contents

def __call__(self, *args, **kwargs):
if self.contents is None:
closure = None
else:
closure = tuple(_make_cell(contents) for contents in self.contents)
fn = _adjust_function_closure(self.fn.fn, closure)
return fn(*args, **kwargs)


def _fixup_closure(leaf):
if isinstance(leaf, types.FunctionType):
return _Closure(leaf)
else:
return leaf


def closure_to_pytree(tree):
"""Convert all function closures into pytree nodes.
**Arguments:**
- `tree`: Any pytree.
**Returns:**
A copy of `tree`, where all function closures have been replaced by a new object
that is (a) callable like the original function, but (b) iterates over its
`__closure__` as subnodes in the pytree.
!!! Example
```python
def some_fn():
a = jnp.array(1.)
@closure_to_pytree
def f(x):
return x + a
print(jax.tree_util.tree_leaves(f)) # prints out `a`
```
!!! Warning
One annoying technical detail in the above example: we had to wrap the whole lot
in a `some_fn`, so that we're in a local scope. Python treats functions at the
global scope differently, and this conversion won't result in any global
variable being treated as part of the pytree.
In practice, the intended use case of this function is to fix Optax, which
always uses local functions.
"""
return jtu.tree_map(_fixup_closure, tree)
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
jaxlib
optax
pytest
beartype
45 changes: 45 additions & 0 deletions tests/test_closure_to_pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import optax

import equinox as eqx
import equinox.internal as eqxi


def test_fixup_optax():
lr = jnp.array(1e-3)
optim = optax.chain(
optax.adam(lr),
optax.scale_by_schedule(optax.piecewise_constant_schedule(1, {200: 0.1})),
)
optim = eqxi.closure_to_pytree(optim)

for leaf in jtu.tree_leaves(optim):
if eqx.is_array(leaf) and leaf == -lr:
break
else:
assert False

# Check that we can still init and update as normal.
grads = params = {"foo": jnp.array(1.0)}
state = optim.init(params)
optim.update(grads, state)

lr = jnp.array(1e-2)
optim2 = optax.chain(
optax.adam(lr),
optax.scale_by_schedule(optax.piecewise_constant_schedule(1, {200: 0.1})),
)
optim2 = eqxi.closure_to_pytree(optim2)

compiling = 0

@eqx.filter_jit
def f(x):
nonlocal compiling
compiling += 1

f(optim)
assert compiling == 1
f(optim2)
assert compiling == 1

0 comments on commit 617e043

Please sign in to comment.