diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index c2cf2e0dcb77..624dff8ebb98 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -330,8 +330,15 @@ def write_primal(v, val): cts_out = get_primitive_transpose(eqn.primitive)( params, call_jaxpr, invals, cts_in, cts_in_avals) else: - cts_out = get_primitive_transpose(eqn.primitive)( - cts_in, *invals, **eqn.params) + try: + cts_out = get_primitive_transpose(eqn.primitive)( + cts_in, *invals, **eqn.params) + except (FloatingPointError, ZeroDivisionError) as e: + msg = "When differentiating the code at the top of the callstack:" + if msg not in e.args[0]: + e.args = e.args[0] + f'\n{msg}', + e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}', + raise e from None cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out # FIXME: Some invars correspond to primals! map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)