From c9b0d1f3cacb158808158d07a0fd29107bf919cc Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 7 Oct 2024 19:35:41 +0200 Subject: [PATCH] More fixes for JAX 0.4.34 --- equinox/_ad.py | 10 +++++++--- equinox/_jit.py | 4 ++-- equinox/internal/_primitive.py | 9 ++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/equinox/_ad.py b/equinox/_ad.py index 2aaac6af..374347e4 100644 --- a/equinox/_ad.py +++ b/equinox/_ad.py @@ -878,9 +878,13 @@ def _none_to_zero(ct, x): if x is None: return None else: - # No raising-to-vspace. JAX is internally inconsistent, and expects integers - # to have integer tangents from custom_{jvp,vjp} rules - aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) # .at_least_vspace() + aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) + if hasattr(aval, "to_tangent_aval"): + # Earlier versions of JAX were internally inconsistent, and expected + # e.g. integer primals to have integer tangents from `custom_{jvp,vjp}` + # rules. + # That changed in JAX 0.4.34. + aval = aval.to_tangent_aval() # pyright: ignore return jax.custom_derivatives.SymbolicZero(aval) else: return ct diff --git a/equinox/_jit.py b/equinox/_jit.py index b1eea161..1d6b190a 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -110,11 +110,11 @@ def _postprocess(out): try: # Added in JAX 0.4.34. - JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore[reportAttributeAccessIssue] + JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore except AttributeError: try: # Forward compatibility in case they ever decide to fix the capitalization. - JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore[reportAttributeAccessIssue] + JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore except AttributeError: # Not public API, so wrap in a try-except for forward compatibility. try: diff --git a/equinox/internal/_primitive.py b/equinox/internal/_primitive.py index 91ffa0e5..c2d63f81 100644 --- a/equinox/internal/_primitive.py +++ b/equinox/internal/_primitive.py @@ -59,7 +59,14 @@ def _is_array_like_internal(x): def _zero_from_primal(p): assert type(p) is not ad.UndefinedPrimal - return ad.Zero(jax.core.get_aval(p).at_least_vspace()) + aval = jax.core.get_aval(p) + if hasattr(aval, "to_tangent_aval"): + # JAX >=0.4.34 + aval = aval.to_tangent_aval() # pyright: ignore + else: + # earlier JAX + aval = aval.at_least_vspace() + return ad.Zero(aval) def _combine(dynamic, static):