Skip to content

Commit

Permalink
Merge pull request #25481 from jax-ml:custom-linearize-process-call
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 706035425
  • Loading branch information
Google-ML-Automation committed Dec 14, 2024
2 parents c73f306 + dea51cb commit f4e5f14
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 27 deletions.
103 changes: 88 additions & 15 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,27 @@ def jvpfun(f, instantiate, transform_stack, primals, tangents):
in zip(out_tangents, instantiate)]
return out_primals, out_tangents

@lu.transformation_with_aux2
def linearize_subtrace(_f, _store, _tag, nzs_in, *primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace()
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval())
for (p, nz) in zip(primals, nzs_in) if nz]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
with core.set_current_trace(linearize_trace):
ans = _f(*tracers)
out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
nzs_out = [type(t) is not Zero for t in out_tangents]
out_tangents = [t for t, nz in zip(out_tangents, nzs_out) if nz]
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
num_residuals = len(consts)
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
_store.store((num_residuals, nzs_out, jaxpr))
return tuple(consts) + tuple(out_primals)

@lu.transformation2
def jvp_subtrace(f, tag, primals, tangents):
with core.take_current_trace() as parent_trace:
Expand Down Expand Up @@ -133,20 +154,21 @@ def new_arg(primal_aval, nz):
out_tangents = [tangent_trace.to_jaxpr_tracer(t)
for (nz, t) in zip(nzs_out, out_tangents) if nz]
tangent_jaxpr, tangent_consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
del attrs_tracked # TODO: attrs
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
residuals_and_primals = (*tangent_consts, *out_primals)
primal_jaxpr, primal_consts, attrs_tracked = primal_trace.to_jaxpr(residuals_and_primals)
num_residuals = len(tangent_consts)
tangent_jaxpr = pe.close_jaxpr(convert_constvars_jaxpr_constvars_at_end(tangent_jaxpr))
del attrs_tracked # TODO: attrs
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
return core.ClosedJaxpr(primal_jaxpr, primal_consts), num_residuals, nzs_out, tangent_jaxpr

def direct_linearize(traceable, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False)
def direct_linearize(traceable, primals, kwargs, *, has_aux=False, tag=None):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace()
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace)
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
with core.set_current_trace(linearize_trace):
if has_aux:
Expand All @@ -163,16 +185,17 @@ def direct_linearize(traceable, *primals, **kwargs):
out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
jaxpr, consts, attrs_tracked = tangent_trace.to_jaxpr(out_tangents)
out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
del attrs_tracked # TODO: attrs
if attrs_tracked:
raise NotImplementedError("TODO: attrs")
if has_aux:
return out_primals, out_tangents_pvals, jaxpr, consts, aux_primals
else:
return out_primals, out_tangents_pvals, jaxpr, consts

def linearize(traceable, *primals, **kwargs):
if config.use_direct_linearize.value:
return direct_linearize(traceable, *primals, **kwargs)
has_aux = kwargs.pop('has_aux', False)
if config.use_direct_linearize.value:
return direct_linearize(traceable, primals, kwargs, has_aux=has_aux)
if not has_aux:
jvpfun = jvp(traceable)
else:
Expand Down Expand Up @@ -558,6 +581,31 @@ def process_primitive(self, primitive, args, params):
else:
return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)

def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
dict(symbolic_zeros=symbolic_zeros))
with core.set_current_trace(self.parent_trace):
if not symbolic_zeros:
tangents_in = map(instantiate_zeros, tangents_in)
else:
tangents_in = map(replace_internal_symbolic_zeros, tangents_in)
nonzeros_in = [type(t) is not Zero for t in tangents_in]

def _f_jvp(primals, tangents):
outs = f_jvp.call_wrapped(*primals, *tangents)
primals_out, tangents_out = split_list(outs, [len(outs) // 2])
return primals_out, tangents_out

primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
_f_jvp, True, nonzeros_in, primals_in, {})
with core.set_current_trace(self.tangent_trace):
tangents_out = linearized(residuals, *tangents_in)
tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
return [maybe_linearize_tracer(self, x, nz, t)
for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)]

def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
Expand All @@ -582,6 +630,29 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)

def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
nzs_in = [type(t) is not Zero for t in tangents]
f_primal, linearize_outs_thunk = linearize_subtrace(f, self.tag, nzs_in)
all_primal_results = call_primitive.bind_with_trace(self.parent_trace, (f_primal, *primals), params)
num_residuals, nzs_out, lin_jaxpr = linearize_outs_thunk()
residuals = all_primal_results[:num_residuals]
primals_out = all_primal_results[num_residuals:]

def f_tangent(*args):
residuals = args[:num_residuals]
nz_tangents = args[num_residuals:]
return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)

nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = call_primitive.bind_with_trace(
self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
return map(partial(maybe_linearize_tracer, self), primals_out, nzs_out, tangents_out)

def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
if is_nonzero:
assert not type(tangent) is Zero
Expand All @@ -590,11 +661,14 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
assert type(tangent) is Zero
return primal

def fallback_linearize_rule(prim, nonzeros, *primals, **params):
jvp = primitive_jvps.get(prim)
def fallback_linearize_rule(_prim, _nonzeros, *primals, **params):
jvp = primitive_jvps.get(_prim)
if not jvp:
msg = f"Differentiation rule for '{prim}' not implemented"
msg = f"Differentiation rule for '{_prim}' not implemented"
raise NotImplementedError(msg)
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, primals, params)

def linearize_from_jvp(jvp, multiple_results, nonzeros, primals, params):
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
Expand All @@ -604,7 +678,7 @@ def fallback_linearize_rule(prim, nonzeros, *primals, **params):
with core.set_current_trace(trace):
out_primals, out_tangents = jvp(primals, tangent_args, **params)

if not prim.multiple_results:
if not multiple_results:
out_primals = [out_primals]
out_tangents = [out_tangents]

Expand All @@ -621,20 +695,19 @@ def linearized(residuals, *tangents):
nz_tangents_out_iter = iter(nz_tangents_out)
all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval)
for (aval, nz) in zip(out_tangent_avals, out_nzs)]
if prim.multiple_results:
if multiple_results:
return all_out_tangents
else:
out_tangent, = all_out_tangents
return out_tangent

if prim.multiple_results:
if multiple_results:
return out_primals, out_nzs, out_consts, linearized
else:
out_primal, = out_primals
out_nz, = out_nzs
return out_primal, out_nz, out_consts, linearized


class LinearizeTracer(Tracer):
__slots__ = ['primal', 'tangent']

Expand Down
29 changes: 17 additions & 12 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,18 +2112,23 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
def tangent_fun(consts_, *tangents):
tangents_nz = _filter_zeros(nzs, tangents)
assert len(consts_) == num_residuals
return pjit_p.bind(*(*tangents_nz, *consts_),
jaxpr=tangent_jaxpr,
in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings,
out_shardings=_filter_zeros(nzs_out, out_shardings),
in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
out_layouts=_filter_zeros(nzs_out, out_layouts),
resource_env=resource_env,
donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
name=name,
keep_unused=keep_unused,
inline=inline,
compiler_options_kvs=compiler_options_kvs)
nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_),
jaxpr=tangent_jaxpr,
in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings,
out_shardings=_filter_zeros(nzs_out, out_shardings),
in_layouts=_filter_zeros(nzs, in_layouts) + res_layouts,
out_layouts=_filter_zeros(nzs_out, out_layouts),
resource_env=resource_env,
donated_invars=_filter_zeros(nzs, donated_invars) + res_donated,
name=name,
keep_unused=keep_unused,
inline=inline,
compiler_options_kvs=compiler_options_kvs)
tangent_avals_out = [v.aval.to_tangent_aval() for v in jaxpr.jaxpr.outvars]
nz_tangents_out_ = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval)
for (aval, nz) in zip(tangent_avals_out, nzs_out)]
return tangents_out

def _filter_zeros(is_nz_l, l):
return tuple(x for nz, x in zip(is_nz_l, l) if nz)
Expand Down

0 comments on commit f4e5f14

Please sign in to comment.