Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix debug_nans regressions. #25519

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open

Conversation

emilyfertig
Copy link
Collaborator

@emilyfertig emilyfertig commented Dec 16, 2024

Fixes #25299

With this fix, debug_nans again reports the line where the NaN first appeared, including in reverse-mode autodiff and inside of pmap/shard_map. The "de-optimized function did not produce invalid values..." message again only appears when it's true.

The approach is to raise an InternalFloatingPointError when a NaN is detected in output, and catch those exceptions at the spot where we can run the de-optimized function.

Future improvements are tracked in #25643.

Example output:

import jax
import jax.numpy as jnp
import traceback

# (1) NaNs on the forward pass
@jax.jit
def f(x):
  y = jnp.square(x)
  return jnp.log(-y)

# (2) NaNs on the forward pass but nested
@jax.jit
def g(x):
  return f(x - 2.)

x = jnp.array([2., 0.])
z = jnp.zeros_like(x)

# (3) NaNs on the backward pass
out, f_vjp = jax.vjp(f, x)
# (4) ...and nested.
out, g_vjp = jax.vjp(g, x)

# (5) NaNs in forward autodiff
f_jvp = lambda: jax.jvp(f, [z], [jnp.ones_like(x)])

# (6) Grad of pmap
_, f_vjp_pmap = jax.vjp(jax.pmap(f), jnp.zeros([1]))

# (7) Grad of shard map
P = jax.sharding.PartitionSpec
mesh = jax.make_mesh((1,), ('x',))
shmap_f = jax.experimental.shard_map.shard_map(f, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
_, f_vjp_shmap = jax.vjp(shmap_f, jnp.zeros([1]))

with jax.debug_nans(True):

  one = jnp.ones([1])
  fns = [lambda: f(x), lambda: g(x), lambda: f_vjp(x), lambda: g_vjp(x), 
         lambda: f_jvp(), lambda: f_vjp_pmap(one), lambda: f_vjp_shmap(one)]
  names = ['f', 'g', 'f_vjp', 'g_vjp', 'f_jvp', 'f_vjp_pmap', 'f_vjp_shmap']

  for fun, n in zip(fns, names):
    try:
      fun().block_until_ready()
    except FloatingPointError as e:
      print(f'---------------------- {n} ---------------------')
      print(traceback.format_exc())
      print("\n\n\n")
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
---------------------- f ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 39, in <lambda>
    fns = [lambda: f(x), lambda: g(x), lambda: f_vjp(x), lambda: g_vjp(x),
                  ^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 9, in f
    return jnp.log(-y)
           ^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/numpy/ufuncs.py", line 489, in log
    return lax.log(*promote_args_inexact('log', x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in log
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.





Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
---------------------- g ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 39, in <lambda>
    fns = [lambda: f(x), lambda: g(x), lambda: f_vjp(x), lambda: g_vjp(x),
                                ^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 14, in g
    return f(x - 2.)
           ^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 9, in f
    return jnp.log(-y)
           ^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/numpy/ufuncs.py", line 489, in log
    return lax.log(*promote_args_inexact('log', x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in log
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.





Invalid nan value encountered in the backward pass of a C++-jit/pmap function. Calling the de-optimized backward pass.
---------------------- f_vjp ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 20, in <module>
    out, f_vjp = jax.vjp(f, x)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in div
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:9:9 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:20:13 (<module>)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 39, in <lambda>
    fns = [lambda: f(x), lambda: g(x), lambda: f_vjp(x), lambda: g_vjp(x),
                                              ^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/tree_util.py", line 477, in __call__
    return self.fun(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in div
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:9:9 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:20:13 (<module>)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.





Invalid nan value encountered in the backward pass of a C++-jit/pmap function. Calling the de-optimized backward pass.
Invalid nan value encountered in the backward pass of a C++-jit/pmap function. Calling the de-optimized backward pass.
---------------------- g_vjp ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 22, in <module>
    out, g_vjp = jax.vjp(g, x)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:8:6 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:14:9 (g)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:22:13 (<module>)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 39, in <lambda>
    fns = [lambda: f(x), lambda: g(x), lambda: f_vjp(x), lambda: g_vjp(x),
                                                                ^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/tree_util.py", line 477, in __call__
    return self.fun(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:8:6 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:14:9 (g)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:22:13 (<module>)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.





Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
---------------------- f_jvp ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 40, in <lambda>
    lambda: f_jvp(), lambda: f_vjp_pmap(one), lambda: f_vjp_shmap(one)]
            ^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 25, in <lambda>
    f_jvp = lambda: jax.jvp(f, [z], [jnp.ones_like(x)])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 9, in f
    return jnp.log(-y)
           ^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/numpy/ufuncs.py", line 489, in log
    return lax.log(*promote_args_inexact('log', x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in div
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.





Invalid nan value encountered in the backward pass of a C++-jit/pmap function. Calling the de-optimized backward pass.
Invalid nan value encountered in the backward pass of a C++-jit/pmap function. Calling the de-optimized backward pass.
---------------------- f_vjp_pmap ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 28, in <module>
    _, f_vjp_pmap = jax.vjp(jax.pmap(f), jnp.zeros([1]))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:8:6 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:28:16 (<module>)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 40, in <lambda>
    lambda: f_jvp(), lambda: f_vjp_pmap(one), lambda: f_vjp_shmap(one)]
                             ^^^^^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/tree_util.py", line 477, in __call__
    return self.fun(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:8:6 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:28:16 (<module>)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.





---------------------- f_vjp_shmap ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 34, in <module>
    _, f_vjp_shmap = jax.vjp(shmap_f, jnp.zeros([1]))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: Invalid value (nan) encountered in sharded computation.
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 40, in <lambda>
    lambda: f_jvp(), lambda: f_vjp_pmap(one), lambda: f_vjp_shmap(one)]
                                                      ^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/tree_util.py", line 477, in __call__
    return self.fun(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: Invalid value (nan) encountered in sharded computation.
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

@emilyfertig emilyfertig requested a review from mattjj December 16, 2024 21:02
@emilyfertig
Copy link
Collaborator Author

Update: I noticed that grad(shmap(...)) wasn't printing the line on which the NaN occurred, so the latest commit fixes that (though I'm not sure if it's the best fix). The error message is now:

---------------------- f_vjp_shmap ---------------------
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 34, in <module>
    _, f_vjp_shmap = jax.vjp(shmap_f, jnp.zeros([1]))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:8:6 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 52, in <module>
    fun().block_until_ready()
    ^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 46, in <lambda>
    lambda: f_jvp(), lambda: f_vjp_pmap(one), lambda: f_vjp_shmap(one),
                                                      ^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/tree_util.py", line 477, in __call__
    return self.fun(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in mul
When differentiating the code at the top of the callstack:
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:8:6 (f)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)
/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py:34:17 (<module>)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

I'm not sure why line 34:17 appears twice.

@emilyfertig
Copy link
Collaborator Author

There's also something weird with pmap:

@jax.jit
def f(x):
  y = jnp.square(x)
  return jnp.log(-y)

f_pmap = jax.pmap(f)

with jax.debug_nans(True):
  one = jnp.ones([1])
  f_pmap(jnp.zeros([1]))  # valid
  f_pmap(one) # invalid

With the fast dispatch path, we get an error pointing to log:

Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    f_pmap(one) # invalid
    ^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 9, in f
    return jnp.log(-y)
           ^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/numpy/ufuncs.py", line 489, in log
    return lax.log(*promote_args_inexact('log', x))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FloatingPointError: invalid value (nan) encountered in log
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

but with the slow path, we just get "parallel computation":

...
with jax.debug_nans(True):
  one = jnp.ones([1])
  # f_pmap(jnp.zeros([1]))  # valid
  f_pmap(one) # invalid
Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/api.py", line 1555, in cache_miss
    out = execute(*p.flat_args)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/interpreters/pxla.py", line 1303, in __call__
    dispatch.check_special(self.name, arrays)
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/dispatch.py", line 318, in check_special
    _check_special(name, buf.dtype, buf)
  File "/usr/local/google/home/emilyaf/jax-fork/jax/jax/_src/dispatch.py", line 323, in _check_special
    raise InternalFloatingPointError(name, "nan")
jax._src.dispatch.InternalFloatingPointError: ('parallel computation', 'nan')

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/google/home/emilyaf/jax-fork/jax/copy_debug_nans.py", line 45, in <module>
    f_pmap(one) # invalid
    ^^^^^^^^^^^
FloatingPointError: Invalid value (nan) encountered in parallel computation.

For shard_map, the traceback only reports "sharded computation" and not the primitive. I'll look at it more later but let me know if you have ideas for a fix @mattjj

jax/_src/dispatch.py Outdated Show resolved Hide resolved
jax/_src/dispatch.py Outdated Show resolved Hide resolved
@emilyfertig emilyfertig mentioned this pull request Dec 20, 2024
5 tasks
Copy link
Collaborator

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for doing this!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Regressions in debug_nans (and debug_infs)
4 participants