diff --git a/equinox/_ad.py b/equinox/_ad.py index 374347e4..665f7792 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -340,7 +340,9 @@ def filter_jvp( flat_tangents = jtu.tree_leaves(tangents) # all non-None tangents are dynamic def _fn(*_flat_dynamic): - _main = jax.core.find_top_trace(_flat_dynamic).main + _top_trace = jax.core.find_top_trace(_flat_dynamic) + assert _top_trace is not None + _main = _top_trace.main _dynamic = jtu.tree_unflatten(treedef, _flat_dynamic) _in = combine(_dynamic, static_primals) _out = fn(*_in, **kwargs) diff --git a/equinox/_vmap_pmap.py b/equinox/_vmap_pmap.py index 1b68afc6..ef917500 100644 --- a/equinox/_vmap_pmap.py +++ b/equinox/_vmap_pmap.py @@ -8,8 +8,6 @@ import jax import jax._src.traceback_util as traceback_util import jax.core -import jax.interpreters.batching as batching -import jax.interpreters.pxla as pxla import jax.numpy as jnp import jax.tree_util as jtu import numpy as np @@ -78,53 +76,6 @@ def __call__(self, x: Any) -> Optional[int]: return self.axis if is_array(x) else None -@dataclasses.dataclass(frozen=True) # not a pytree -class if_mapped: - """Used with the `out_axes` argument of [`equinox.filter_vmap`][], to only add an - output batch axis if necessary. - - That is, `out_axes=if_mapped(i)` is equivalent to `out_axes=i` for any output that - is batched, and `out_axes=None` fofr any output that is not batched. - """ - - axis: int - - def __call__(self, x: Any): - raise RuntimeError( - "`eqx.internal.if_mapped` should not be called directly; it is only valid " - "when passed to `out_axes` of `eqx.filter_vmap`." - ) - - -@dataclasses.dataclass(frozen=True) # not a pytree -class _if_mapped: - main: Any - axis: int - - def __call__(self, x: Any) -> Optional[int]: - if isinstance(x, batching.BatchTracer) and x._trace.main is self.main: - if x.batch_dim is batching.not_mapped: - return None - else: - return self.axis - elif isinstance(x, pxla.MapTracer) and x._trace.main is self.main: - return self.axis - else: - return None - - -# The existence of this function is a complete hack: it couples together `filter_vmap` -# with `if_mapped`. I don't see an obvious way around it though. -def _bind_main(main, out_axes): - def _bind(axis): - if isinstance(axis, if_mapped): - return _if_mapped(main, axis.axis) - else: - return axis - - return jtu.tree_map(_bind, out_axes) - - def _moveaxis(array, axis): return jnp.moveaxis(array, 0, axis) @@ -199,11 +150,9 @@ def __call__(self, /, *args, **kwargs): static_args, dynamic_args = partition(args, unmapped_axis) def _fun_wrapper(_dynamic_args): - _main = jax.core.find_top_trace(jtu.tree_leaves(_dynamic_args)).main _args = combine(_dynamic_args, static_args) _out = self._fun(*_args) - _out_axes = _bind_main(_main, self._out_axes) - _out_axes = _resolve_axes(_out, _out_axes) + _out_axes = _resolve_axes(_out, self._out_axes) _none_axes = jtu.tree_map(_is_none, _out_axes, is_leaf=_is_none) _nonvmapd, _vmapd = partition(_out, _none_axes, is_leaf=_is_none) _nonvmapd_arr, _nonvmapd_static = partition(_nonvmapd, is_array) @@ -235,6 +184,7 @@ def _fun_wrapper(_dynamic_args): return combine(vmapd, nonvmapd) def __get__(self, instance, owner): + del owner if instance is None: return self return Partial(self, instance) @@ -439,10 +389,8 @@ def _check_map_out_axis(x: Optional[int]): ) def fun_wrapped(_dynamic): - _main = jax.core.find_top_trace(jtu.tree_leaves(_dynamic)) _fun, _args, _, _out_axes = combine(_dynamic, static) _out = _fun(*_args) - _out_axes = _bind_main(_main, _out_axes) _out_axes = _resolve_axes(_out, _out_axes) jtu.tree_map(_check_map_out_axis, _out_axes) _pmapd = [] @@ -558,6 +506,7 @@ def lower(self, /, *args, **kwargs) -> Lowered: return self._call(True, args, kwargs) def __get__(self, instance, owner): + del owner if instance is None: return self return Partial(self, instance) diff --git a/equinox/internal/__init__.py b/equinox/internal/__init__.py index 75a60cdd..0fb2f266 100644 --- a/equinox/internal/__init__.py +++ b/equinox/internal/__init__.py @@ -30,7 +30,6 @@ unvmap_max as unvmap_max, unvmap_max_p as unvmap_max_p, ) -from .._vmap_pmap import if_mapped as if_mapped # Backward compatibility: expose via `equinox.internal`. Now available under # `equinox.debug`. diff --git a/tests/test_pmap.py b/tests/test_pmap.py index 0635d91c..8e526ce4 100644 --- a/tests/test_pmap.py +++ b/tests/test_pmap.py @@ -272,23 +272,6 @@ def f(x, y): compiled(x, y) -def test_double_if_mapped(): - out_axes = eqx.internal.if_mapped(1) - - def f(x): - assert x.shape == (3, 1) - - def g(y): - assert y.shape == (1,) - return y + 1, x + 1 - - a, b = eqx.filter_vmap(g, out_axes=out_axes)(x) - assert a.shape == (1, 3) - assert b.shape == (3, 1) - - filter_pmap(f)(jnp.arange(3).reshape(1, 3, 1)) - - # https://github.com/patrick-kidger/equinox/issues/900 # Unlike the vmap case we only test nonnegative integers, as pmap does not support # negative indexing for `in_axes` or `out_axes`. diff --git a/tests/test_vmap.py b/tests/test_vmap.py index ba196be3..d3facaf6 100644 --- a/tests/test_vmap.py +++ b/tests/test_vmap.py @@ -160,23 +160,6 @@ def test_keyword_default(getkey): eqx.filter_vmap(lambda x, y=1: x, in_axes=dict(y=0))(x) -def test_double_if_mapped(): - out_axes = eqx.internal.if_mapped(1) - - def f(x): - assert x.shape == (3, 1) - - def g(y): - assert y.shape == (1,) - return y + 1, x + 1 - - a, b = eqx.filter_vmap(g, out_axes=out_axes)(x) - assert a.shape == (1, 3) - assert b.shape == (3, 1) - - eqx.filter_vmap(f)(jnp.arange(6).reshape(2, 3, 1)) - - # https://github.com/patrick-kidger/equinox/issues/900 @pytest.mark.parametrize("out_axes", (0, 1, 2, -1, -2, -3)) def test_out_axes_with_at_least_three_dimensions(out_axes):