Skip to content

Commit

Permalink
Address reviewer comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyfertig committed Dec 17, 2024
1 parent fa40530 commit 3f4e54b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def prim_fun(*args):
return prim.bind(*args, **params)
prim_fun.__name__ = prim.name
prim_fun.__qualname__ = prim.name
prim_fun.__is_primitive__ = True
prim_fun._apply_primitive = True
return api.jit(prim_fun)


Expand Down Expand Up @@ -339,7 +339,7 @@ def __init__(self, name: str, ty: str):

def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs,
) -> None: # always raises an exception
print("Invalid nan value encountered in the output of a C++-jit/pmap "
print("Invalid nan value encountered in the output of a jax.jit "
"function. Calling the de-optimized version.")
try:
_ = fun(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def out_axes_thunk():
try:
out_flat = primitive.bind(fun, *all_args, **new_params)
except dispatch.InternalFloatingPointError as e:
print("Invalid nan value encountered in the backward pass of a C++-jit/pmap "
print("Invalid nan value encountered in the backward pass of a jax.jit "
"function. Calling the de-optimized backward pass.")
try:
_ = backward_pass(call_jaxpr, None, {}, args, ct)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
f' {type(arg)} is not a valid JAX type.') from e
raise AssertionError("Unreachable") from e
except dispatch.InternalFloatingPointError as e:
if getattr(fun, '__is_primitive__', False):
if getattr(fun, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None
dispatch.maybe_recursive_nan_check(e, fun, args, kwargs)

Expand Down Expand Up @@ -2371,7 +2371,7 @@ def prune_type(ty, xs, maybe_zeros):
inline=inline,
compiler_options_kvs=compiler_options_kvs)
except dispatch.InternalFloatingPointError as e:
print("Invalid nan value encountered in the backward pass of a C++-jit/pmap "
print("Invalid nan value encountered in the backward pass of a jax.jit "
"function. Calling the de-optimized backward pass.")
try:
_ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in)
Expand Down

0 comments on commit 3f4e54b

Please sign in to comment.