Skip to content

Commit

Permalink
Reinstate the call stack from eqn.source_info in the backward pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyfertig committed Dec 21, 2024
1 parent 46f5a52 commit 02f23c3
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,15 @@ def write_primal(v, val):
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in, cts_in_avals)
else:
cts_out = get_primitive_transpose(eqn.primitive)(
cts_in, *invals, **eqn.params)
try:
cts_out = get_primitive_transpose(eqn.primitive)(
cts_in, *invals, **eqn.params)
except (FloatingPointError, ZeroDivisionError) as e:
msg = "When differentiating the code at the top of the callstack:"
if msg not in e.args[0]:
e.args = e.args[0] + f'\n{msg}',
e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}',
raise e from None
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
# FIXME: Some invars correspond to primals!
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
Expand Down

0 comments on commit 02f23c3

Please sign in to comment.