Skip to content

Commit

Permalink
Relax tolerance for LAX reduction test in float16.
Browse files Browse the repository at this point in the history
At `float16` precision, one LAX reduction test was found to be flaky and was disabled in #25443. This change re-enables that test with a slightly relaxed tolerance instead.

PiperOrigin-RevId: 706771186
  • Loading branch information
dfm authored and Google-ML-Automation committed Dec 16, 2024
1 parent b3177da commit ed4e982
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,6 @@ def np_fun(x):
))
def testReducerPromoteInt(self, name, rng_factory, shape, dtype, axis,
keepdims, initial, inexact, promote_integers):
if jtu.test_device_matches(["cpu"]) and name == "sum" and config.enable_x64.value and dtype == np.float16:
raise unittest.SkipTest("sum op fails in x64 mode on CPU with dtype=float16") # b/383756018
np_op = getattr(np, name)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
Expand All @@ -364,7 +362,7 @@ def np_fun(x):
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
tol = {jnp.bfloat16: 3E-2}
tol = {jnp.bfloat16: 3E-2, jnp.float16: 5e-3}
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
self._CompileAndCheck(jnp_fun, args_maker)

Expand Down

0 comments on commit ed4e982

Please sign in to comment.