Skip to content

Commit

Permalink
More fixes for JAX 0.4.34
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 7, 2024
1 parent 9121934 commit f4b723f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
10 changes: 7 additions & 3 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,9 +878,13 @@ def _none_to_zero(ct, x):
if x is None:
return None
else:
# No raising-to-vspace. JAX is internally inconsistent, and expects integers
# to have integer tangents from custom_{jvp,vjp} rules
aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) # .at_least_vspace()
aval = jax.core.raise_to_shaped(jax.core.get_aval(x))
if hasattr(aval, "to_tangent_aval"):
# Earlier versions of JAX were internally inconsistent, and expected
# e.g. integer primals to have integer tangents from `custom_{jvp,vjp}`
# rules.
# That changed in JAX 0.4.34.
aval = aval.to_tangent_aval() # pyright: ignore
return jax.custom_derivatives.SymbolicZero(aval)
else:
return ct
Expand Down
9 changes: 8 additions & 1 deletion equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ def _is_array_like_internal(x):

def _zero_from_primal(p):
assert type(p) is not ad.UndefinedPrimal
return ad.Zero(jax.core.get_aval(p).at_least_vspace())
aval = jax.core.get_aval(p)
if hasattr(aval, "to_tangent_aval"):
# JAX >=0.4.34
aval = aval.to_tangent_aval() # pyright: ignore
else:
# earlier JAX
aval = aval.at_least_vspace()
return ad.Zero(aval)


def _combine(dynamic, static):
Expand Down

0 comments on commit f4b723f

Please sign in to comment.