diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index c0ce8d09fb83..e8a56110fee3 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -495,15 +495,25 @@ def __call__(self, *args, **kwargs): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug) assert not len(consts) closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) + + propagate_user_sharding = None + infer_sharding_from_operands = None + sharding_rule = None + if config.use_shardy_partitioner.value: + sharding_rule = self.sharding_rule + else: + propagate_user_sharding = self.propagate_user_sharding + infer_sharding_from_operands = self.infer_sharding_from_operands + out_flat = custom_partitioning_p.bind( *consts, *args_flat, call=closed_call, partition=self.partition, - propagate_user_sharding=self.propagate_user_sharding, - infer_sharding_from_operands=self.infer_sharding_from_operands, + propagate_user_sharding=propagate_user_sharding, + infer_sharding_from_operands=infer_sharding_from_operands, decode_shardings=self.decode_shardings, - sharding_rule=self.sharding_rule, + sharding_rule=sharding_rule, in_tree=in_tree, out_tree=out_tree(), static_args=static_args