diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index dd13ff2a4763..295fe9159f59 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -156,7 +156,7 @@ def f(x): with self.assertRaisesRegex( FloatingPointError, - r"Invalid value \(nan\) encountered in sharded computation.\nWhen differentiating"): + r"invalid value \(nan\) encountered in mul\nWhen differentiating"): ans, = f_vjp(jnp.ones([1])) ans.block_until_ready()