Skip to content

Commit

Permalink
More linearize fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Dec 14, 2024
1 parent f4e5f14 commit 00e9918
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,10 @@ def direct_linearize(traceable, primals, kwargs, *, has_aux=False, tag=None):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace()
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with core.set_current_trace(linearize_trace):
if has_aux:
ans, aux = traceable.call_wrapped(*tracers)
Expand Down Expand Up @@ -622,8 +624,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [core.get_aval(x).to_tangent_aval() for x in primals_out]

tangents_in = map(instantiate_zeros, tangents_in)
with core.set_current_trace(self.tangent_trace):
tangents_in = map(instantiate_zeros, tangents_in)
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
Expand Down

0 comments on commit 00e9918

Please sign in to comment.