Skip to content

Commit

Permalink
Allow P.UNCONSTRAINED in out_shardings at top level jit. This is re…
Browse files Browse the repository at this point in the history
…quired for sharding in types to work properly when out_avals contain UNCONSTRAINED specs.

This also simplifies the `impl` rule of `sharding_cast`.

PiperOrigin-RevId: 706966253
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 17, 2024
1 parent 7dd401c commit 30726bc
Show file tree
Hide file tree
Showing 12 changed files with 338 additions and 54 deletions.
2 changes: 0 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,8 +2187,6 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return src # pytype: disable=bad-return-type
if isinstance(x, array.ArrayImpl):
return x.sharding
if config.sharding_in_types.value and hasattr(x, 'sharding'):
return x.sharding
if isinstance(x, core.Tracer):
val = x.to_concrete_value()
if val is not None and isinstance(val, array.ArrayImpl):
Expand Down
40 changes: 29 additions & 11 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,12 @@ def __init__(self, compute_type: str | None, threefry_partitionable: bool,
xla_metadata=None):
self.compute_type = compute_type
self.threefry_partitionable = threefry_partitionable
self.cur_abstract_mesh = mesh_lib.get_abstract_mesh()
self.xla_metadata = xla_metadata
self._managers = [
(compute_on.extend_compute_type, self.compute_type),
(config.threefry_partitionable.__call__, self.threefry_partitionable),
(mesh_lib.set_abstract_mesh, self.cur_abstract_mesh),
(xla_metadata_lib.set_xla_metadata, self.xla_metadata),
]

Expand All @@ -292,6 +294,7 @@ def __repr__(self):
return (
f"JaxprEqnContext(compute_type={self.compute_type}, "
f"threefry_partitionable={self.threefry_partitionable}, "
f"cur_abstract_mesh={self.cur_abstract_mesh}, "
f"xla_metadata={self.xla_metadata})"
)

Expand Down Expand Up @@ -535,6 +538,17 @@ def write(v: Var, val: Any) -> None:
clean_up_dead_vars(eqn, env, lu)
return map(read, jaxpr.outvars)

def check_avals_context_mesh(avals, prim_name):
  """Check that every aval's sharding mesh equals the current abstract mesh.

  The check only runs when `config.sharding_in_types` is enabled; otherwise
  the function returns immediately without inspecting `avals`.

  Args:
    avals: iterable of abstract values. Each one is read for `.sharding.mesh`
      and, when an error is reported, `.str_short()`.
    prim_name: name of the primitive being checked; only used to build the
      error message.

  Raises:
    ValueError: if some aval's `sharding.mesh` differs from the mesh returned
      by `mesh_lib.get_abstract_mesh()`.
  """
  if not config.sharding_in_types.value:
    return
  # Nothing in the loop body changes the context mesh, so read it once
  # instead of once per aval.
  cur_mesh = mesh_lib.get_abstract_mesh()
  for a in avals:
    if a.sharding.mesh != cur_mesh:
      raise ValueError(
          f"For primitive {prim_name}, context mesh {cur_mesh} should match"
          f" the aval mesh {a.sharding.mesh} for shape {a.str_short()}. This"
          " error occurs at source: "
          f" {source_info_util.summarize(source_info_util.current())}")


# -------------------- tracing --------------------

Expand Down Expand Up @@ -1622,7 +1636,10 @@ def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding # type: ignore

if sharding is not None:
assert len(sharding.spec) == ndim
if len(sharding.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
return _maybe_modify_sharding(sharding)

context_mesh = mesh_lib.get_abstract_mesh()
Expand Down Expand Up @@ -2518,17 +2535,18 @@ def write(v: Var, a: AbstractValue) -> None:
in_avals = [x.aval for x in in_atoms] # use in_atoms for dyn shapes

# Compute the type of the primitive application.
if prim in custom_typechecks:
out_type, eqn_effects = custom_typechecks[prim](
ctx_factory, *in_atoms, **eqn.params)
elif prim.call_primitive:
out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
with eqn.ctx.manager:
if prim in custom_typechecks:
out_type, eqn_effects = custom_typechecks[prim](
ctx_factory, *in_atoms, **eqn.params)
elif prim.call_primitive:
out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
eqn.params)
elif prim.map_primitive:
out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
eqn.params)
elif prim.map_primitive:
out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
eqn.params)
else:
out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)
else:
out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)

# Check the computed effect type matches the eqn's annotation, and is
# included in the jaxpr's annotation.
Expand Down
14 changes: 0 additions & 14 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,6 @@ def _batched_device_put_impl(
device_put_p.def_impl(_batched_device_put_impl)

def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics):
if config.sharding_in_types.value:
return [x.update(sharding=s) for x, s in zip(xs, devices)]
return xs
device_put_p.def_abstract_eval(_device_put_abstract_eval)

Expand Down Expand Up @@ -566,12 +564,6 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
# TODO(yashkatariya): Maybe we should add the custom calls anyways if it's
# being used inside jit? At least for now, this preserves the old behavior.
if ctx.module_context.all_default_mem_kind:
if config.sharding_in_types.value:
return [
mlir.wrap_with_sharding_op(
ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto())
for x, a in zip(xs, ctx.avals_out)
]
return xs
def lower(x, device, aval, out_aval):
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
Expand All @@ -597,12 +589,6 @@ def lower(x, device, aval, out_aval):


def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
if config.sharding_in_types.value:
return [
mlir.wrap_with_sharding_op(
ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto())
for x, a in zip(xs, ctx.avals_out)
]
return xs
mlir.register_lowering(device_put_p, _common_device_put_lowering)

Expand Down
40 changes: 34 additions & 6 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import AUTO
from jax._src.sharding_impls import AUTO, NamedSharding
from jax._src.partition_spec import UnconstrainedSingleton
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
Expand Down Expand Up @@ -1055,6 +1055,17 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
assert isinstance(s, JSharding)
return s.memory_kind

def contains_unconstrained(s):
  """Return True if `s` is a NamedSharding whose parsed spec has a None entry.

  Any value that is not a `NamedSharding` yields False without its
  `_parsed_pspec` being read.
  """
  if not isinstance(s, NamedSharding):
    return False
  return None in s._parsed_pspec

def all_unconstrained(s):
  """Return True if `s` is a NamedSharding whose parsed spec is all None.

  Any value that is not a `NamedSharding` yields False. An empty parsed spec
  yields True, because `all` over no entries is True.
  """
  if not isinstance(s, NamedSharding):
    return False
  return all(entry is None for entry in s._parsed_pspec)

def _get_unconstrained_dimensions(s):
  """Summarize which dimensions of `s` have a None entry in the parsed spec.

  Returns:
    A 3-tuple `(has_any, is_all, dims)` where `has_any` is
    `contains_unconstrained(s)`, `is_all` is `all_unconstrained(s)`, and
    `dims` is the set of indices whose parsed-spec entry is None, or None
    when `has_any` is False.
  """
  has_any = contains_unconstrained(s)
  is_all = all_unconstrained(s)
  if has_any:
    dims = {idx for idx, entry in enumerate(s._parsed_pspec) if entry is None}
  else:
    # `_parsed_pspec` is only read once `s` is known to be a NamedSharding.
    dims = None
  return has_any, is_all, dims


def lower_jaxpr_to_module(
module_name: str,
Expand Down Expand Up @@ -1114,7 +1125,8 @@ def lower_jaxpr_to_module(
f"only {platforms_with_donation} support donation")
if (num_partitions > 1 and
(result_shardings is None or
all(s is None or isinstance(s, AUTO) for s in result_shardings))):
all(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
for s in result_shardings))):
xla_donated_args = donated_args
donated_args = [False] * len(donated_args)
if xla_donated_args is None:
Expand Down Expand Up @@ -1448,7 +1460,8 @@ def lower_jaxpr_to_fun(
ir_arg_memory_kinds = None
if arg_memory_kinds is not None:
ir_arg_memory_kinds = util.flatten(
[[mk] * len_ir_types(types) for mk, types in zip(arg_memory_kinds, input_types)])
[[mk] * len_ir_types(types)
for mk, types in zip(arg_memory_kinds, input_types)])

ir_arg_layouts = None
if arg_layouts is not None:
Expand All @@ -1459,13 +1472,18 @@ def lower_jaxpr_to_fun(
ir_donated_args = None
if xla_donated_args is not None:
ir_donated_args = util.flatten(
[[is_donated] * len_ir_types(types) for is_donated, types in zip(xla_donated_args, input_types)])
[[is_donated] * len_ir_types(types)
for is_donated, types in zip(xla_donated_args, input_types)])

ir_result_shardings = None
unconstrained_shardings = None
if result_shardings is not None:
ir_result_shardings = util.flatten(
[[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
for a, s, types in zip(output_avals, result_shardings, output_types)])
unconstrained_shardings = util.flatten(
[[_get_unconstrained_dimensions(s)] * len_ir_types(types)
for s, types in zip(result_shardings, output_types)])

ir_result_memory_kinds = None
custom_call_ir_result_memory_kinds = None
Expand Down Expand Up @@ -1580,8 +1598,9 @@ def lower_jaxpr_to_fun(
attrs['jax.result_info'] = ir.StringAttr.get(name_)

if use_sharding_annotations and ir_result_shardings is not None:
for attrs, sharding in zip(result_attrs, ir_result_shardings):
if sharding is not None:
for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
unconstrained_shardings):
if sharding is not None and not us[0]:
if config.use_shardy_partitioner.value:
attrs["sdy.sharding"] = get_sharding_attr(sharding)
else:
Expand Down Expand Up @@ -1658,6 +1677,15 @@ def lower_jaxpr_to_fun(
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

if ir_result_shardings is not None:
flat_outputs = [
wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
unspecified_dims=us[2])
if us[0] and not us[1] else o
for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
output_avals, unconstrained_shardings)
]

# Insert a custom call if output is on host because XLA needs that to do the
# transfer.
if custom_call_ir_result_memory_kinds is not None and name == "main":
Expand Down
13 changes: 9 additions & 4 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,8 +2162,7 @@ def _abstract_to_concrete_mesh(abstract_mesh):

out = []
for s, a in zip(shardings, avals):
if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:
Expand Down Expand Up @@ -2792,6 +2791,11 @@ def _maybe_get_and_check_out_shardings(
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
new_out_shardings.append(xla_s)
elif mlir.contains_unconstrained(orig):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore
else:
xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # pytype: disable=attribute-error
Expand Down Expand Up @@ -2907,8 +2911,9 @@ def from_hlo(name: str,

allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
for i in in_shardings)
allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
for o in out_shardings)
allow_prop_to_outputs = tuple(
isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o)
for o in out_shardings)

mesh = None
if auto_spmd_lowering:
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
if config.sharding_in_types.value:
core.check_avals_context_mesh(args, 'all_reduce')
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,
sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes))
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/lax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
weak_type = weak_type_rule(*avals, **kwargs)
least_specialized = type(max(avals, key=_get_array_abstraction_level))
if least_specialized is core.ShapedArray:
core.check_avals_context_mesh(avals, prim.name)
return core.ShapedArray(
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
weak_type=weak_type,
Expand All @@ -78,6 +79,7 @@ def standard_multi_result_abstract_eval(
if least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
core.check_avals_context_mesh(avals, prim.name)
out_shardings = (sharding_rule(*avals, **kwargs)
if config.sharding_in_types.value else
[None] * len(out_shapes))
Expand Down
25 changes: 24 additions & 1 deletion jax/_src/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,25 @@ def local_devices(self):
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)

def with_axis_types(self, new_axis_types) -> Mesh:
  """Return a Mesh with this mesh's devices and axis names and new axis types.

  Args:
    new_axis_types: value forwarded as the `axis_types` keyword argument of
      the `Mesh` constructor.
  """
  devices, axis_names = self.devices, self.axis_names
  return Mesh(devices, axis_names, axis_types=new_axis_types)

@functools.cached_property
def _are_all_axes_collective(self) -> bool:
  """True when every key of `self.axis_types` equals `AxisTypes.Collective`.

  Also True when `self.axis_types` has no keys.
  """
  axis_kinds = self.axis_types.keys()
  return all(kind == AxisTypes.Collective for kind in axis_kinds)

@functools.cached_property
def _are_all_axes_auto(self) -> bool:
  """True when every key of `self.axis_types` equals `AxisTypes.Auto`.

  Also True when `self.axis_types` has no keys.
  """
  axis_kinds = self.axis_types.keys()
  return all(kind == AxisTypes.Auto for kind in axis_kinds)

@functools.cached_property
def _any_axis_collective(self) -> bool:
  """True when at least one key of `self.axis_types` is `AxisTypes.Collective`.

  False when `self.axis_types` has no keys.
  """
  axis_kinds = self.axis_types.keys()
  return any(kind == AxisTypes.Collective for kind in axis_kinds)

@functools.cached_property
def _any_axis_auto(self) -> bool:
  """True when at least one key of `self.axis_types` is `AxisTypes.Auto`.

  False when `self.axis_types` has no keys.
  """
  axis_kinds = self.axis_types.keys()
  return any(kind == AxisTypes.Auto for kind in axis_kinds)


EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))

Expand Down Expand Up @@ -396,8 +415,9 @@ def __eq__(self, other):
self._axis_types_tuple == other._axis_types_tuple)

def __repr__(self):
mesh_repr = ", ".join(f"'{n}': {v}" for n, v in self.shape_tuple)
atr = f", axis_types={self.axis_types}"
return f"AbstractMesh({self.shape_tuple}{atr})"
return f"AbstractMesh({mesh_repr}{atr})"

@property
def axis_names(self):
Expand Down Expand Up @@ -427,6 +447,9 @@ def _internal_device_list(self):
def empty(self):
return self.size == 0

def with_axis_types(self, new_axis_types) -> AbstractMesh:
  """Return an AbstractMesh with this mesh's shape tuple and new axis types.

  Args:
    new_axis_types: value forwarded as the `axis_types` keyword argument of
      the `AbstractMesh` constructor.
  """
  shape_tuple = self.shape_tuple
  return AbstractMesh(shape_tuple, axis_types=new_axis_types)

@functools.cached_property
def _are_all_axes_collective(self) -> bool:
  """True when every key of `self.axis_types` equals `AxisTypes.Collective`.

  Also True when `self.axis_types` has no keys.
  """
  axis_kinds = self.axis_types.keys()
  return all(kind == AxisTypes.Collective for kind in axis_kinds)
Expand Down
Loading

0 comments on commit 30726bc

Please sign in to comment.