-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix debug_nans regressions. #25519
base: main
Are you sure you want to change the base?
Fix debug_nans regressions. #25519
Conversation
Update: I noticed that grad(shmap(...)) wasn't printing the line on which the NaN occurred, so the latest commit fixes that (though I'm not sure if it's the best fix). The error message is now:
I'm not sure why line 34:17 appears twice. |
There's also something weird with pmap: @jax.jit
def f(x):
y = jnp.square(x)
return jnp.log(-y)
f_pmap = jax.pmap(f)
with jax.debug_nans(True):
one = jnp.ones([1])
f_pmap(jnp.zeros([1])) # valid
f_pmap(one) # invalid With the fast dispatch path, we get an error pointing to log:
but with the slow path, we just get "parallel computation": ...
with jax.debug_nans(True):
one = jnp.ones([1])
# f_pmap(jnp.zeros([1])) # valid
f_pmap(one) # invalid
For shard_map, the traceback only reports "sharded computation" and not the primitive. I'll look at it more later but let me know if you have ideas for a fix @mattjj |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks for doing this!
b7fd036
to
46f5a52
Compare
Co-authored-by: Matt Johnson <[email protected]>
02f23c3
to
023ad2e
Compare
Fixes #25299
With this fix,
debug_nans
again reports the line where the NaN first appeared, including in reverse-mode autodiff and inside of pmap/shard_map. The "de-optimized function did not produce invalid values..." message again only appears when it's true.The approach is to raise an InternalFloatingPointError when a NaN is detected in output, and catch those exceptions at the spot where we can run the de-optimized function.
Future improvements are tracked in #25643.
Example output: