Skip to content

Commit

Permalink
Remove if_mapped to skip a usage of find_top_trace.
Browse files Browse the repository at this point in the history
We still have one in `filter_jvp` that might need some later resolution, as JAX now has a comment indicating that they intend to remove this function.

`if_mapped` was never documented (and frankly not that useful either) so I'm treating its removal here as a non-breaking change.
  • Loading branch information
patrick-kidger committed Dec 8, 2024
1 parent 09530f1 commit d7d2cb9
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 90 deletions.
4 changes: 3 additions & 1 deletion equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ def filter_jvp(
flat_tangents = jtu.tree_leaves(tangents) # all non-None tangents are dynamic

def _fn(*_flat_dynamic):
_main = jax.core.find_top_trace(_flat_dynamic).main
_top_trace = jax.core.find_top_trace(_flat_dynamic)
assert _top_trace is not None
_main = _top_trace.main
_dynamic = jtu.tree_unflatten(treedef, _flat_dynamic)
_in = combine(_dynamic, static_primals)
_out = fn(*_in, **kwargs)
Expand Down
57 changes: 3 additions & 54 deletions equinox/_vmap_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.interpreters.batching as batching
import jax.interpreters.pxla as pxla
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
Expand Down Expand Up @@ -78,53 +76,6 @@ def __call__(self, x: Any) -> Optional[int]:
return self.axis if is_array(x) else None


@dataclasses.dataclass(frozen=True) # not a pytree
class if_mapped:
"""Used with the `out_axes` argument of [`equinox.filter_vmap`][], to only add an
output batch axis if necessary.
That is, `out_axes=if_mapped(i)` is equivalent to `out_axes=i` for any output that
is batched, and `out_axes=None` fofr any output that is not batched.
"""

axis: int

def __call__(self, x: Any):
raise RuntimeError(
"`eqx.internal.if_mapped` should not be called directly; it is only valid "
"when passed to `out_axes` of `eqx.filter_vmap`."
)


@dataclasses.dataclass(frozen=True) # not a pytree
class _if_mapped:
main: Any
axis: int

def __call__(self, x: Any) -> Optional[int]:
if isinstance(x, batching.BatchTracer) and x._trace.main is self.main:
if x.batch_dim is batching.not_mapped:
return None
else:
return self.axis
elif isinstance(x, pxla.MapTracer) and x._trace.main is self.main:
return self.axis
else:
return None


# The existence of this function is a complete hack: it couples together `filter_vmap`
# with `if_mapped`. I don't see an obvious way around it though.
def _bind_main(main, out_axes):
def _bind(axis):
if isinstance(axis, if_mapped):
return _if_mapped(main, axis.axis)
else:
return axis

return jtu.tree_map(_bind, out_axes)


def _moveaxis(array, axis):
return jnp.moveaxis(array, 0, axis)

Expand Down Expand Up @@ -199,11 +150,9 @@ def __call__(self, /, *args, **kwargs):
static_args, dynamic_args = partition(args, unmapped_axis)

def _fun_wrapper(_dynamic_args):
_main = jax.core.find_top_trace(jtu.tree_leaves(_dynamic_args)).main
_args = combine(_dynamic_args, static_args)
_out = self._fun(*_args)
_out_axes = _bind_main(_main, self._out_axes)
_out_axes = _resolve_axes(_out, _out_axes)
_out_axes = _resolve_axes(_out, self._out_axes)
_none_axes = jtu.tree_map(_is_none, _out_axes, is_leaf=_is_none)
_nonvmapd, _vmapd = partition(_out, _none_axes, is_leaf=_is_none)
_nonvmapd_arr, _nonvmapd_static = partition(_nonvmapd, is_array)
Expand Down Expand Up @@ -235,6 +184,7 @@ def _fun_wrapper(_dynamic_args):
return combine(vmapd, nonvmapd)

def __get__(self, instance, owner):
del owner
if instance is None:
return self
return Partial(self, instance)
Expand Down Expand Up @@ -439,10 +389,8 @@ def _check_map_out_axis(x: Optional[int]):
)

def fun_wrapped(_dynamic):
_main = jax.core.find_top_trace(jtu.tree_leaves(_dynamic))
_fun, _args, _, _out_axes = combine(_dynamic, static)
_out = _fun(*_args)
_out_axes = _bind_main(_main, _out_axes)
_out_axes = _resolve_axes(_out, _out_axes)
jtu.tree_map(_check_map_out_axis, _out_axes)
_pmapd = []
Expand Down Expand Up @@ -558,6 +506,7 @@ def lower(self, /, *args, **kwargs) -> Lowered:
return self._call(True, args, kwargs)

def __get__(self, instance, owner):
del owner
if instance is None:
return self
return Partial(self, instance)
Expand Down
1 change: 0 additions & 1 deletion equinox/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
unvmap_max as unvmap_max,
unvmap_max_p as unvmap_max_p,
)
from .._vmap_pmap import if_mapped as if_mapped

# Backward compatibility: expose via `equinox.internal`. Now available under
# `equinox.debug`.
Expand Down
17 changes: 0 additions & 17 deletions tests/test_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,23 +272,6 @@ def f(x, y):
compiled(x, y)


def test_double_if_mapped():
out_axes = eqx.internal.if_mapped(1)

def f(x):
assert x.shape == (3, 1)

def g(y):
assert y.shape == (1,)
return y + 1, x + 1

a, b = eqx.filter_vmap(g, out_axes=out_axes)(x)
assert a.shape == (1, 3)
assert b.shape == (3, 1)

filter_pmap(f)(jnp.arange(3).reshape(1, 3, 1))


# https://github.com/patrick-kidger/equinox/issues/900
# Unlike the vmap case we only test nonnegative integers, as pmap does not support
# negative indexing for `in_axes` or `out_axes`.
Expand Down
17 changes: 0 additions & 17 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,6 @@ def test_keyword_default(getkey):
eqx.filter_vmap(lambda x, y=1: x, in_axes=dict(y=0))(x)


def test_double_if_mapped():
out_axes = eqx.internal.if_mapped(1)

def f(x):
assert x.shape == (3, 1)

def g(y):
assert y.shape == (1,)
return y + 1, x + 1

a, b = eqx.filter_vmap(g, out_axes=out_axes)(x)
assert a.shape == (1, 3)
assert b.shape == (3, 1)

eqx.filter_vmap(f)(jnp.arange(6).reshape(2, 3, 1))


# https://github.com/patrick-kidger/equinox/issues/900
@pytest.mark.parametrize("out_axes", (0, 1, 2, -1, -2, -3))
def test_out_axes_with_at_least_three_dimensions(out_axes):
Expand Down

0 comments on commit d7d2cb9

Please sign in to comment.