diff --git a/jax/_src/api.py b/jax/_src/api.py index 02902125b3b5..690fb24d796f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2187,8 +2187,6 @@ def _infer_src_sharding(src, x) -> Sharding | None: return src # pytype: disable=bad-return-type if isinstance(x, array.ArrayImpl): return x.sharding - if config.sharding_in_types.value and hasattr(x, 'sharding'): - return x.sharding if isinstance(x, core.Tracer): val = x.to_concrete_value() if val is not None and isinstance(val, array.ArrayImpl): diff --git a/jax/_src/core.py b/jax/_src/core.py index 870ec57cd693..1250a3f4d954 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -273,10 +273,12 @@ def __init__(self, compute_type: str | None, threefry_partitionable: bool, xla_metadata=None): self.compute_type = compute_type self.threefry_partitionable = threefry_partitionable + self.cur_abstract_mesh = mesh_lib.get_abstract_mesh() self.xla_metadata = xla_metadata self._managers = [ (compute_on.extend_compute_type, self.compute_type), (config.threefry_partitionable.__call__, self.threefry_partitionable), + (mesh_lib.set_abstract_mesh, self.cur_abstract_mesh), (xla_metadata_lib.set_xla_metadata, self.xla_metadata), ] @@ -292,6 +294,7 @@ def __repr__(self): return ( f"JaxprEqnContext(compute_type={self.compute_type}, " f"threefry_partitionable={self.threefry_partitionable}, " + f"cur_abstract_mesh={self.cur_abstract_mesh}, " f"xla_metadata={self.xla_metadata})" ) @@ -535,6 +538,17 @@ def write(v: Var, val: Any) -> None: clean_up_dead_vars(eqn, env, lu) return map(read, jaxpr.outvars) +def check_avals_context_mesh(avals, prim_name): + if config.sharding_in_types.value: + for a in avals: + cur_mesh = mesh_lib.get_abstract_mesh() + if a.sharding.mesh != cur_mesh: + raise ValueError( + f"For primitive {prim_name}, context mesh {cur_mesh} should match" + f" the aval mesh {a.sharding.mesh} for shape {a.str_short()}. This" + " error occurs at source: " + f" {source_info_util.summarize(source_info_util.current())}") + # -------------------- tracing -------------------- @@ -1622,7 +1636,10 @@ def get_sharding(sharding, ndim): from jax._src.sharding_impls import NamedSharding # type: ignore if sharding is not None: - assert len(sharding.spec) == ndim + if len(sharding.spec) != ndim: + raise ValueError( + "Length of sharding.spec must be equal to aval's ndim. Got" + f" sharding.spec {sharding.spec} and aval.ndim {ndim}") return _maybe_modify_sharding(sharding) context_mesh = mesh_lib.get_abstract_mesh() @@ -2518,17 +2535,18 @@ def write(v: Var, a: AbstractValue) -> None: in_avals = [x.aval for x in in_atoms] # use in_atoms for dyn shapes # Compute the type of the primitive application. - if prim in custom_typechecks: - out_type, eqn_effects = custom_typechecks[prim]( - ctx_factory, *in_atoms, **eqn.params) - elif prim.call_primitive: - out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms, + with eqn.ctx.manager: + if prim in custom_typechecks: + out_type, eqn_effects = custom_typechecks[prim]( + ctx_factory, *in_atoms, **eqn.params) + elif prim.call_primitive: + out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms, + eqn.params) + elif prim.map_primitive: + out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals, eqn.params) - elif prim.map_primitive: - out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals, - eqn.params) - else: - out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params) + else: + out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params) # Check the computed effect type matches the eqn's annotation, and is # included in the jaxpr's annotation. diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 72e918241ea5..7b94d8be63be 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -522,8 +522,6 @@ def _batched_device_put_impl( device_put_p.def_impl(_batched_device_put_impl) def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics): - if config.sharding_in_types.value: - return [x.update(sharding=s) for x, s in zip(xs, devices)] return xs device_put_p.def_abstract_eval(_device_put_abstract_eval) @@ -566,12 +564,6 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's # being used inside jit? Atleast for now, this preserves the old behavior. if ctx.module_context.all_default_mem_kind: - if config.sharding_in_types.value: - return [ - mlir.wrap_with_sharding_op( - ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto()) - for x, a in zip(xs, ctx.avals_out) - ] return xs def lower(x, device, aval, out_aval): if (isinstance(device, (Sharding, TransferToMemoryKind)) and @@ -597,12 +589,6 @@ def lower(x, device, aval, out_aval): def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics): - if config.sharding_in_types.value: - return [ - mlir.wrap_with_sharding_op( - ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto()) - for x, a in zip(xs, ctx.avals_out) - ] return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 0e3fdea02301..755780694d0e 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -49,7 +49,7 @@ from jax._src.interpreters import xla from jax._src.layout import AutoLayout, DeviceLocalLayout from jax._src.sharding import Sharding as JSharding -from jax._src.sharding_impls import AUTO +from jax._src.sharding_impls import AUTO, NamedSharding from jax._src.partition_spec import UnconstrainedSingleton from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension @@ -1055,6 +1055,17 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None: assert isinstance(s, JSharding) return s.memory_kind +def contains_unconstrained(s): + return isinstance(s, NamedSharding) and None in s._parsed_pspec + +def all_unconstrained(s): + return isinstance(s, NamedSharding) and all(p is None for p in s._parsed_pspec) + +def _get_unconstrained_dimensions(s): + us = contains_unconstrained(s) + return (us, all_unconstrained(s), + ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None)) + def lower_jaxpr_to_module( module_name: str, @@ -1114,7 +1125,8 @@ def lower_jaxpr_to_module( f"only {platforms_with_donation} support donation") if (num_partitions > 1 and (result_shardings is None or - all(s is None or isinstance(s, AUTO) for s in result_shardings))): + all(s is None or isinstance(s, AUTO) or contains_unconstrained(s) + for s in result_shardings))): xla_donated_args = donated_args donated_args = [False] * len(donated_args) if xla_donated_args is None: @@ -1448,7 +1460,8 @@ def lower_jaxpr_to_fun( ir_arg_memory_kinds = None if arg_memory_kinds is not None: ir_arg_memory_kinds = util.flatten( - [[mk] * len_ir_types(types) for mk, types in zip(arg_memory_kinds, input_types)]) + [[mk] * len_ir_types(types) + for mk, types in zip(arg_memory_kinds, input_types)]) ir_arg_layouts = None if arg_layouts is not None: @@ -1459,13 +1472,18 @@ def lower_jaxpr_to_fun( ir_donated_args = None if xla_donated_args is not None: ir_donated_args = util.flatten( - [[is_donated] * len_ir_types(types) for is_donated, types in zip(xla_donated_args, input_types)]) + [[is_donated] * len_ir_types(types) + for is_donated, types in zip(xla_donated_args, input_types)]) ir_result_shardings = None + unconstrained_shardings = None if result_shardings is not None: ir_result_shardings = util.flatten( [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types) for a, s, types in zip(output_avals, result_shardings, output_types)]) + unconstrained_shardings = util.flatten( + [[_get_unconstrained_dimensions(s)] * len_ir_types(types) + for s, types in zip(result_shardings, output_types)]) ir_result_memory_kinds = None custom_call_ir_result_memory_kinds = None @@ -1580,8 +1598,9 @@ def lower_jaxpr_to_fun( attrs['jax.result_info'] = ir.StringAttr.get(name_) if use_sharding_annotations and ir_result_shardings is not None: - for attrs, sharding in zip(result_attrs, ir_result_shardings): - if sharding is not None: + for attrs, sharding, us in zip(result_attrs, ir_result_shardings, + unconstrained_shardings): + if sharding is not None and not us[0]: if config.use_shardy_partitioner.value: attrs["sdy.sharding"] = get_sharding_attr(sharding) else: @@ -1658,6 +1677,15 @@ def lower_jaxpr_to_fun( o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s) for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)] + if ir_result_shardings is not None: + flat_outputs = [ + wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s, + unspecified_dims=us[2]) + if us[0] and not us[1] else o + for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings, + output_avals, unconstrained_shardings) + ] + # Insert a custom call if output is on host because XLA needs that to do the # transfer. if custom_call_ir_result_memory_kinds is not None and name == "main": diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 840b9e522968..f68dc85a4ec9 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2162,8 +2162,7 @@ def _abstract_to_concrete_mesh(abstract_mesh): out = [] for s, a in zip(shardings, avals): - if (isinstance(s, UnspecifiedValue) and a.sharding is not None and - all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)): + if isinstance(s, UnspecifiedValue) and a.sharding is not None: out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh), a.sharding.spec)) else: @@ -2792,6 +2791,11 @@ def _maybe_get_and_check_out_shardings( dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) new_out_shardings.append(xla_s) + elif mlir.contains_unconstrained(orig): + if (aval is not core.abstract_token and + dtypes.issubdtype(aval.dtype, dtypes.extended)): + xla_s = sharding_impls.logical_sharding(aval, xla_s) + new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # pytype: disable=attribute-error @@ -2907,8 +2911,9 @@ def from_hlo(name: str, allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO)) for i in in_shardings) - allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO)) - for o in out_shardings) + allow_prop_to_outputs = tuple( + isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o) + for o in out_shardings) mesh = None if auto_spmd_lowering: diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 46fd82533038..012e0b8331a1 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -687,6 +687,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") if config.sharding_in_types.value: + core.check_avals_context_mesh(args, 'all_reduce') out_avals = [ ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 78d125436029..1bdf1f16ae5f 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -53,6 +53,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, weak_type = weak_type_rule(*avals, **kwargs) least_specialized = type(max(avals, key=_get_array_abstraction_level)) if least_specialized is core.ShapedArray: + core.check_avals_context_mesh(avals, prim.name) return core.ShapedArray( shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs), weak_type=weak_type, @@ -78,6 +79,7 @@ def standard_multi_result_abstract_eval( if least_specialized is core.ShapedArray: out_shapes = shape_rule(*avals, **kwargs) out_dtypes = dtype_rule(*avals, **kwargs) + core.check_avals_context_mesh(avals, prim.name) out_shardings = (sharding_rule(*avals, **kwargs) if config.sharding_in_types.value else [None] * len(out_shapes)) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 2351266c6c24..25fb2b38f7fa 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -350,6 +350,25 @@ def local_devices(self): def abstract_mesh(self): return AbstractMesh(self.shape_tuple, axis_types=self.axis_types) + def with_axis_types(self, new_axis_types) -> Mesh: + return Mesh(self.devices, self.axis_names, axis_types=new_axis_types) + + @functools.cached_property + def _are_all_axes_collective(self) -> bool: + return all(t == AxisTypes.Collective for t in self.axis_types.keys()) + + @functools.cached_property + def _are_all_axes_auto(self) -> bool: + return all(t == AxisTypes.Auto for t in self.axis_types.keys()) + + @functools.cached_property + def _any_axis_collective(self) -> bool: + return any(t == AxisTypes.Collective for t in self.axis_types.keys()) + + @functools.cached_property + def _any_axis_auto(self) -> bool: + return any(t == AxisTypes.Auto for t in self.axis_types.keys()) + EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -396,8 +415,9 @@ def __eq__(self, other): self._axis_types_tuple == other._axis_types_tuple) def __repr__(self): + mesh_repr = ", ".join(f"'{n}': {v}" for n, v in self.shape_tuple) atr = f", axis_types={self.axis_types}" - return f"AbstractMesh({self.shape_tuple}{atr})" + return f"AbstractMesh({mesh_repr}{atr})" @property def axis_names(self): @@ -427,6 +447,9 @@ def _internal_device_list(self): def empty(self): return self.size == 0 + def with_axis_types(self, new_axis_types) -> AbstractMesh: + return AbstractMesh(self.shape_tuple, axis_types=new_axis_types) + @functools.cached_property def _are_all_axes_collective(self) -> bool: return all(t == AxisTypes.Collective for t in self.axis_types.keys()) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 3cb63749237c..405fadc2f896 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -61,8 +61,8 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc -from jax._src import sharding from jax._src.mesh import AbstractMesh +from jax._src.sharding import Sharding from jax._src.sharding_impls import ( NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, @@ -73,7 +73,7 @@ from jax._src.tree_util import ( tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves, treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, - PyTreeDef, none_leaf_registry as none_lr) + PyTreeDef, none_leaf_registry as none_lr, tree_map) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, weakref_lru_cache, @@ -1027,7 +1027,7 @@ def hashable_pytree(pytree): def _create_sharding_for_array(mesh, x, name, api_name): if x is None and (mesh is None or mesh.empty): return UNSPECIFIED - if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)): + if isinstance(x, (AUTO, UnspecifiedValue, Sharding)): return x if mesh is None: msg = ('jax.jit only supports `Sharding`s being passed to' @@ -1339,7 +1339,7 @@ def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) - if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)): + if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)): out_shardings_flat = (orig_out_shardings,) * len(out_avals) else: out_shardings_flat = flatten_axis_resources( @@ -1571,7 +1571,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] else: resolved_in_shardings.append(arg_s) else: - assert isinstance(arg_s, sharding.Sharding) + assert isinstance(arg_s, Sharding) if dispatch.is_single_device_sharding(arg_s): resolved_in_shardings.append(UNSPECIFIED) else: @@ -1903,7 +1903,7 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params): core.custom_typechecks[pjit_p] = _pjit_typecheck -def _pjit_abstract_eval(*args, jaxpr, **_): +def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_): return jaxpr.out_avals, jaxpr.effects pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval) @@ -2016,7 +2016,7 @@ def _pjit_batcher(axis_data, vals_in, dims_in, batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule def _pjit_batcher_for_sharding( - s: sharding.Sharding | UnspecifiedValue, + s: Sharding | UnspecifiedValue, dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): if isinstance(s, UnspecifiedValue): return s @@ -2673,6 +2673,67 @@ def _sharding_constraint_batcher( batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher batching.skippable_batchers[sharding_constraint_p] = lambda _: () +# -------------------- sharding_cast --------------------------- + +def sharding_cast(xs, shardings): + if isinstance(shardings, NamedSharding): + return tree_map(lambda x: sharding_cast_p.bind( + x, src_sharding=x.sharding, dst_sharding=shardings), xs) + + x_flat, treedef = tree_flatten(xs) + shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings) + out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s) + for x, s in safe_zip(x_flat, shardings_flat)] + return tree_unflatten(treedef, out_flat) + +sharding_cast_p = core.Primitive('sharding_cast') +def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding): + if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple: + raise ValueError( + f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not' + ' match the mesh shape of the target sharding' + f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}') + return aval.update(sharding=dst_sharding) +sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval) + +def _sharding_cast_impl(x, src_sharding, dst_sharding): + return dispatch.apply_primitive(sharding_cast_p, x, src_sharding=src_sharding, + dst_sharding=dst_sharding) +sharding_cast_p.def_impl(_sharding_cast_impl) + +def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding): + return [sharding_cast_p.bind(ct, src_sharding=dst_sharding, + dst_sharding=src_sharding)] +ad.deflinear2(sharding_cast_p, _sharding_cast_transpose_rule) + +def _sharding_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding): + aval, = ctx.avals_in + aval_out, = ctx.avals_out + proto = dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto() + return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)] +mlir.register_lowering(sharding_cast_p, _sharding_cast_hlo_lowering) + +# TODO(yashkatariya): Comment this in after vmap ShiT tests are added. +# def _sharding_cast_batcher(axis_data, vals_in, dims_in, src_sharding, +# dst_sharding): +# if axis_data.spmd_name is not None: +# used = {n for ns in dst_sharding.spec +# for n in (ns if isinstance(ns, tuple) else (ns,))} +# if set(axis_data.spmd_name) & used: +# raise ValueError( +# f'vmap spmd_axis_name {axis_data.spmd_name} cannot ' +# f'appear in sharding_cast spec, but got spec {dst_sharding.spec}') +# x, = vals_in +# d, = dims_in + +# val = None if axis_data.spmd_name is None else axis_data.spmd_name +# new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val)) +# vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec) +# y = sharding_cast_p.bind(x, src_sharding=src_sharding, +# dst_sharding=vmapped_dst_sharding) +# return y, d +# batching.fancy_primitive_batchers[sharding_cast_p] = _sharding_cast_batcher +# batching.skippable_batchers[sharding_cast_p] = lambda _: () # -------------------- helpers -------------------- diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index ad3190a6d481..69e2adc4dc77 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -69,8 +69,6 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes): @util.cache(max_size=128, trace_context_in_key=False) def _check_axis_type_consistency(mesh, parsed_pspec): - if mesh.axis_types is None: - return for p in parsed_pspec: if p is not None: if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p): @@ -78,6 +76,11 @@ def _check_axis_type_consistency(mesh, parsed_pspec): 'AxisTypes should be the same in a tuple subset of PartitionSpec:' f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis' f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})') + if mesh_lib.AxisTypes.Auto not in mesh.axis_types and None in parsed_pspec: + raise ValueError( + f'PartitionSpec {parsed_pspec.get_partition_spec()} cannot contain' + ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh' + f' axis_types: {mesh.axis_types}') def hashed_index(x) -> int: @@ -271,11 +274,15 @@ def __init__( self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec) def __repr__(self): - mesh_repr = ", ".join(f"'{k}': {v}" for k, v in self.mesh.shape.items()) mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}' ldi = ('' if self._logical_device_ids is None else f', logical_device_ids={self._logical_device_ids}') - return f'NamedSharding(mesh=Mesh({mesh_repr}), spec={self.spec}{mem}{ldi})' + if isinstance(self.mesh, mesh_lib.AbstractMesh): + mesh_repr = f"{self.mesh}" + else: + nv_str = ", ".join(f"'{n}': {v}" for n, v in self.mesh.shape.items()) + mesh_repr = f"Mesh({nv_str})" + return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})' def __reduce__(self): return (type(self), (self.mesh, self.spec), @@ -381,6 +388,9 @@ def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding: spec = PartitionSpec(*spec) return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind) + def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding: + return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind) + def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index d2233362ae90..332074f919dd 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -172,6 +172,8 @@ def test_primitive_coverage(self): continue if p.name == "sharding_constraint": continue + if p.name == "sharding_cast": + continue # TODO: Remove once tensorflow is 2.10.0 everywhere. if p.name == "optimization_barrier": continue diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8fc52df012d1..95f9712fa98e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -52,7 +52,7 @@ from jax._src.sharding_impls import ( AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, SingleDeviceSharding, parse_flatten_op_sharding) -from jax._src.pjit import pjit +from jax._src.pjit import pjit, sharding_cast from jax._src import mesh as mesh_lib from jax._src.interpreters import pxla from jax._src.lib.mlir import dialects @@ -4680,6 +4680,31 @@ def g(x, y): RuntimeError, 'A jitted computation cannot contain AbstractMesh'): lowered3.compile() + def test_jit_out_shardings_unconstrained(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, s) + + out_s = NamedSharding(mesh, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + @partial(jax.jit, out_shardings=out_s) + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, np_inp * 2) + + @partial(jax.jit, out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y'))) + def g(x): + return x * 3 + + out = g(arr) + self.assertArraysEqual(out, np_inp * 3) + self.assertEqual(out.sharding, s) + lowered_text = g.lower(arr).as_text() + self.assertIn("unspecified_dims=[0]", lowered_text) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") @@ -5040,7 +5065,7 @@ def f(x, y): return x + y with self.assertRaisesRegex( - ValueError, "Mesh for all inputs should be equal"): + ValueError, "For primitive add, context mesh.*aval mesh"): f(arr1, arr2) @jtu.with_user_mesh((2, 2), ('x', 'y')) @@ -5264,14 +5289,14 @@ def f(pred, on_true, on_false): f(arr1 == arr2, arr1, arr3) @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_device_put_reshard(self, mesh): + def test_sharding_cast_reshard(self, mesh): np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @jax.jit def f(x): - y = jax.device_put(x, NamedSharding(x.sharding.mesh, P('x', None))) + y = sharding_cast(x, NamedSharding(x.sharding.mesh, P('x', None))) self.assertEqual(y.sharding.spec, P('x', None)) return y @@ -5548,7 +5573,7 @@ def f(x, x2): return a out = f(arr, arr.T) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x',))) def test_auto_user(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'), @@ -5607,6 +5632,131 @@ def test_where_with_scalar(self, mesh): self.assertArraysEqual(out, x) self.assertEqual(out.sharding, s) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_full_user_to_full_auto(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = x * 2 + auto_mesh = mesh_lib.get_abstract_mesh().with_axis_types( + {mesh_lib.AxisTypes.Auto: ('x', 'y')}) + y = sharding_cast(y, y.sharding.with_mesh(auto_mesh)) + with mesh_lib.set_abstract_mesh(auto_mesh): + self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + z = jnp.sin(y) + self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + a = z @ z.T + self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + a = sharding_cast( + a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None))) + self.assertEqual(a.sharding.spec, P('x', None)) + return a + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + jaxpr = f.trace(arr).jaxpr + out2 = core.jaxpr_as_fun(jaxpr)(arr) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_user_mesh((2, 2), ('x', 'y'), + axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')}) + def test_full_auto_to_full_user(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = x * 2 + user_mesh = mesh_lib.get_abstract_mesh().with_axis_types( + {mesh_lib.AxisTypes.User: ('x', 'y')}) + y = sharding_cast(y, NamedSharding(user_mesh, P(None, 'y'))) + with mesh_lib.set_abstract_mesh(user_mesh): + self.assertEqual(y.sharding.spec, P(None, 'y')) + z = jnp.sin(y) + self.assertEqual(z.sharding.spec, P(None, 'y')) + a = z @ z.T + self.assertEqual(a.sharding.spec, P(None, None)) + a = sharding_cast( + a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None))) + self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, None)) + return a + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + + jaxpr = f.trace(arr).jaxpr + core.jaxpr_as_fun(jaxpr)(arr) # doesn't crash + + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_full_user_to_auto_user_mix(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + y = x * 2 + mix_mesh = mesh_lib.get_abstract_mesh().with_axis_types( + {mesh_lib.AxisTypes.Auto: 'x', mesh_lib.AxisTypes.User: 'y'}) + y = sharding_cast(y, y.sharding.with_mesh(mix_mesh)) + with mesh_lib.set_abstract_mesh(mix_mesh): + self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, 'y')) + z = jnp.sin(y) + self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, 'y')) + a = z @ z.T + self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED)) + a = sharding_cast( + a, NamedSharding(mesh_lib.get_abstract_mesh(), P('x', None))) + self.assertEqual(a.sharding.spec, P('x', None)) + return a + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + jaxpr = f.trace(arr).jaxpr + out2 = core.jaxpr_as_fun(jaxpr)(arr) + self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', None))) + + @jtu.with_user_mesh((2, 1), ('x', 'y')) + def test_user_auto_mix_error(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x, y): + x = x * 2 + mix_mesh = mesh_lib.get_abstract_mesh().with_axis_types( + {mesh_lib.AxisTypes.Auto: 'x', mesh_lib.AxisTypes.User: 'y'}) + with mesh_lib.set_abstract_mesh(mix_mesh): + z = x @ y + return z + + with self.assertRaisesRegex( + ValueError, "For primitive dot_general, context mesh.*aval mesh"): + f(arr, arr.T) + + def test_sharding_cast_src_dst_mesh_mismatch(self): + np_inp = np.arange(16.).reshape(8, 2) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + mesh2 = jtu.create_mesh((2, 1), ('a', 'b')) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + f = lambda x: sharding_cast(x, NamedSharding(mesh2, P('a', 'b'))) + with self.assertRaisesRegex( + ValueError, "Mesh shape of the input.*does not match"): + f(arr) + + with mesh_lib.set_mesh(mesh): + with self.assertRaisesRegex( + ValueError, "Mesh shape of the input.*does not match"): + jax.jit(f)(arr) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):