Skip to content

Commit

Permalink
More fixes for JAX 0.4.34
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 8, 2024
1 parent 9121934 commit c9b0d1f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
10 changes: 7 additions & 3 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,9 +878,13 @@ def _none_to_zero(ct, x):
if x is None:
return None
else:
# No raising-to-vspace. JAX is internally inconsistent, and expects integers
# to have integer tangents from custom_{jvp,vjp} rules
aval = jax.core.raise_to_shaped(jax.core.get_aval(x)) # .at_least_vspace()
aval = jax.core.raise_to_shaped(jax.core.get_aval(x))
if hasattr(aval, "to_tangent_aval"):
# Earlier versions of JAX were internally inconsistent, and expected
# e.g. integer primals to have integer tangents from `custom_{jvp,vjp}`
# rules.
# That changed in JAX 0.4.34.
aval = aval.to_tangent_aval() # pyright: ignore
return jax.custom_derivatives.SymbolicZero(aval)
else:
return ct
Expand Down
4 changes: 2 additions & 2 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def _postprocess(out):

try:
# Added in JAX 0.4.34.
JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore[reportAttributeAccessIssue]
JaxRuntimeError = jax.errors.JaxRuntimeError # pyright: ignore
except AttributeError:
try:
# Forward compatibility in case they ever decide to fix the capitalization.
JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore[reportAttributeAccessIssue]
JaxRuntimeError = jax.errors.JAXRuntimeError # pyright: ignore
except AttributeError:
# Not public API, so wrap in a try-except for forward compatibility.
try:
Expand Down
9 changes: 8 additions & 1 deletion equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ def _is_array_like_internal(x):

def _zero_from_primal(p):
assert type(p) is not ad.UndefinedPrimal
return ad.Zero(jax.core.get_aval(p).at_least_vspace())
aval = jax.core.get_aval(p)
if hasattr(aval, "to_tangent_aval"):
# JAX >=0.4.34
aval = aval.to_tangent_aval() # pyright: ignore
else:
# earlier JAX
aval = aval.at_least_vspace()
return ad.Zero(aval)


def _combine(dynamic, static):
Expand Down

0 comments on commit c9b0d1f

Please sign in to comment.