Skip to content

Commit

Permalink
#sdy dynamically choose which custom_partitioning API to use based …
Browse files Browse the repository at this point in the history
…on the current

value of the `use_shardy_partitioner` feature flag.

Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that.

PiperOrigin-RevId: 715293018
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Jan 14, 2025
1 parent 4f2f5fa commit 74e912c
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions jax/_src/custom_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,15 +495,25 @@ def __call__(self, *args, **kwargs):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())

propagate_user_sharding = None
infer_sharding_from_operands = None
sharding_rule = None
if config.use_shardy_partitioner.value:
sharding_rule = self.sharding_rule
else:
propagate_user_sharding = self.propagate_user_sharding
infer_sharding_from_operands = self.infer_sharding_from_operands

out_flat = custom_partitioning_p.bind(
*consts,
*args_flat,
call=closed_call,
partition=self.partition,
propagate_user_sharding=self.propagate_user_sharding,
infer_sharding_from_operands=self.infer_sharding_from_operands,
propagate_user_sharding=propagate_user_sharding,
infer_sharding_from_operands=infer_sharding_from_operands,
decode_shardings=self.decode_shardings,
sharding_rule=self.sharding_rule,
sharding_rule=sharding_rule,
in_tree=in_tree,
out_tree=out_tree(),
static_args=static_args
Expand Down

0 comments on commit 74e912c

Please sign in to comment.