diff --git a/jax/_src/api.py b/jax/_src/api.py index 5585544d80d7..c1bb9ff72968 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -67,8 +67,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind, - NamedSharding) +from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind from jax._src.layout import Layout, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util @@ -2564,11 +2563,7 @@ def _sds_aval_mapping(x): aval = ShapedArray( x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), weak_type=x.weak_type) - if config.sharding_in_types.value and isinstance(x.sharding, NamedSharding): - return aval.update(sharding=NamedSharding( - x.sharding.mesh.abstract_mesh, - x.sharding.spec._normalized_spec(x.ndim))) - return aval + return core.update_aval_with_sharding(aval, x.sharding) core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping diff --git a/jax/_src/array.py b/jax/_src/array.py index 802a3f5f6a5d..6d3d311bbb7e 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -41,7 +41,7 @@ from jax._src.lib import xla_extension as xe from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, NamedSharding, + PmapSharding, SingleDeviceSharding, device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache @@ -753,7 +753,8 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: first_value = per_device_values[0] expected_dtype = first_value.dtype expected_shape = sharding.shard_shape(shape) - aval = core.ShapedArray(shape, expected_dtype) + aval = core.update_aval_with_sharding( + core.ShapedArray(shape, expected_dtype), sharding) _validate_shape_and_dtype_for_per_device_arrays( per_device_values, expected_shape=expected_shape, @@ -1017,7 +1018,8 @@ def make_array_from_single_device_arrays( raise ValueError( "jax.make_array_from_single_device_arrays requires a list of concrete" f" arrays as input. got types {set(map(type, arrays))}") - aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False) + aval = core.update_aval_with_sharding( + core.ShapedArray(shape, arrays[0].dtype, weak_type=False), sharding) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) @@ -1028,13 +1030,7 @@ def make_array_from_single_device_arrays( xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity def _get_aval_array(self): - if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): - return self.aval.update(sharding=NamedSharding( - self.sharding.mesh.abstract_mesh, - self.sharding.spec._normalized_spec(self.ndim))) - else: - return self.aval - + return core.update_aval_with_sharding(self.aval, self.sharding) core.pytype_aval_mappings[ArrayImpl] = _get_aval_array # TODO(jakevdp) replace this with true inheritance at the C++ level. 
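Illustrative note (not part of the patch): the api.py and array.py hunks above route ShapeDtypeStruct abstractification, per-device array assembly, and make_array_from_single_device_arrays through the shared core.update_aval_with_sharding helper added in the core.py hunk below. A minimal sketch of the public entry point affected, with the mesh shape and axis name chosen purely for illustration and assuming at least two devices:

# Sketch only: exercises jax.make_array_from_single_device_arrays, whose
# result aval now carries the sharding when config.sharding_in_types is on.
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))          # assumes >= 2 devices
s = NamedSharding(mesh, P('x'))
global_shape = (8, 2)
data = np.arange(16.0).reshape(global_shape)
arrays = [jax.device_put(data[idx], d)
          for d, idx in s.addressable_devices_indices_map(global_shape).items()]
out = jax.make_array_from_single_device_arrays(global_shape, s, arrays)
assert out.sharding == s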
@@ -1179,6 +1175,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): def _array_global_result_handler(global_aval, out_sharding, committed): + global_aval = core.update_aval_with_sharding(global_aval, out_sharding) if global_aval.dtype == dtypes.float0: return lambda _: np.zeros(global_aval.shape, dtypes.float0) if dtypes.issubdtype(global_aval.dtype, dtypes.extended): diff --git a/jax/_src/core.py b/jax/_src/core.py index df061d5f8b8f..9deac2b52211 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -816,6 +816,12 @@ def __getattr__(self, name): # if the aval property raises an AttributeError, gets caught here assert not config.enable_checks.value or name != "aval" + if name == 'sharding': + raise AttributeError( + self, + f"The 'sharding' attribute is not available on {self._error_repr()}." + f"{self._origin_msg()}") + try: attr = getattr(self.aval, name) except AttributeError as err: @@ -1421,6 +1427,13 @@ def check_valid_jaxtype(x): raise TypeError( f"Value {x!r} of type {type(x)} is not a valid JAX type") +def update_aval_with_sharding(aval, sharding): + from jax._src.sharding_impls import NamedSharding # type: ignore + if config.sharding_in_types.value and isinstance(sharding, NamedSharding): + aval = aval.update(sharding=NamedSharding( + sharding.mesh.abstract_mesh, sharding.spec._normalized_spec(aval.ndim))) + return aval + # We have three flavors of abstractification APIs here which each used to have # their own separate implementation. Now they're effectively the same, with the @@ -1433,8 +1446,6 @@ def check_valid_jaxtype(x): # TODO(jakevdp): can these be unified further? def shaped_abstractify(x): - from jax._src.sharding_impls import NamedSharding # type: ignore - typ = type(x) if (aval_fn := pytype_aval_mappings.get(typ)): # fast path return aval_fn(x) @@ -1448,12 +1459,7 @@ def shaped_abstractify(x): if hasattr(x, 'dtype'): aval = ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False)) - if (config.sharding_in_types.value and hasattr(x, 'sharding') and - isinstance(x.sharding, NamedSharding)): - return aval.update(sharding=NamedSharding( - x.sharding.mesh.abstract_mesh, - x.sharding.spec._normalized_spec(aval.ndim))) - return aval + return update_aval_with_sharding(aval, getattr(x, 'sharding', None)) raise TypeError( f"Cannot interpret value of type {typ} as an abstract array; it " "does not have a dtype attribute") @@ -1701,13 +1707,17 @@ def get_sharding(sharding, ndim): raise ValueError( "Length of sharding.spec must be equal to aval's ndim. Got" f" sharding.spec {sharding.spec} and aval.ndim {ndim}") - return _maybe_modify_sharding(sharding) - - context_mesh = mesh_lib.get_abstract_mesh() - if not context_mesh: - raise RuntimeError("Please set the mesh via `jax.set_mesh` API.") - assert sharding is None - return NamedSharding(context_mesh, P(*[None] * ndim)) + out_s = _maybe_modify_sharding(sharding) + else: + context_mesh = mesh_lib.get_abstract_mesh() + if not context_mesh: + raise RuntimeError("Please set the mesh via `jax.set_mesh` API.") + assert sharding is None + out_s = NamedSharding(context_mesh, P(*[None] * ndim)) + if not isinstance(out_s.mesh, mesh_lib.AbstractMesh): + raise ValueError("Mesh of an aval must be an AbstractMesh. 
" + f"Got {out_s.mesh} of type {type(out_s.mesh)}") + return out_s class ShapedArray(UnshapedArray): @@ -1720,9 +1730,6 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None): self.weak_type = weak_type if config.sharding_in_types.value: self.sharding = get_sharding(sharding, len(self.shape)) - if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh): - raise ValueError( - f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}") def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1796,14 +1803,6 @@ def _get_shape_sharding_str(shape, spec): out.append(f"{s1}@{s2}") return ','.join(out) -def _get_abstract_sharding(val): - from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error - - if (config.sharding_in_types.value and hasattr(val, 'sharding') and - isinstance(val.sharding, NamedSharding)): - return NamedSharding(val.sharding.mesh.abstract_mesh, - val.sharding.spec._normalized_spec(val.ndim)) - return None def primal_dtype_to_tangent_dtype(primal_dtype): if isinstance(primal_dtype, dtypes.ExtendedDType): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 8e18d14559b8..7e54b68defc3 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -230,9 +230,9 @@ def scan(f, init, xs, length=None): if not hasattr(x, 'shape')))) from err if (config.sharding_in_types.value and - not all(x.sharding.spec[0] is None for x in xs_flat)): + not all(x.aval.sharding.spec[0] is None for x in xs_flat)): raise ValueError('0th dimension of all xs should be replicated. Got ' - f'{", ".join(str(x.sharding.spec) for x in xs_flat)}') + f'{", ".join(str(x.aval.sharding.spec) for x in xs_flat)}') if length is not None: try: diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 580154d012ac..427d51e98cb2 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -586,7 +586,7 @@ def _convert_element_type( if (config.sharding_in_types.value and sharding is None and isinstance(operand, Array)): - sharding = operand.sharding + sharding = operand.aval.sharding sharding = canonicalize_sharding(sharding, check_mesh_consistency=False) # type: ignore @@ -1920,7 +1920,8 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) fill_value = _convert_element_type(fill_value, dtype, weak_type) if (sharding is not None and not isinstance(sharding, PmapSharding) and - isinstance(fill_value, array.ArrayImpl)): + isinstance(fill_value, array.ArrayImpl) and + not config.sharding_in_types.value): broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) return array.make_array_from_callback(shape, sharding, lambda _: shard) @@ -2137,7 +2138,7 @@ def full_like(x: ArrayLike | DuckTypedArray, if (config.sharding_in_types.value and sharding is None and isinstance(x, Array)): - sharding = x.sharding + sharding = x.aval.sharding else: # If `x` has a sharding but no `_committed` attribute # (in case of ShapeDtypeStruct), default it to True. 
@@ -4496,7 +4497,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, broadcast_dimensions=broadcast_dimensions) if config.sharding_in_types.value: if sharding is not None: - assert sharding == aval_out.sharding + assert sharding == aval_out.sharding, (sharding, aval_out.sharding) return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] @@ -5656,7 +5657,7 @@ def _compute_argminmax(value_comparator, get_identity, axis, = axes indices = broadcasted_iota( index_dtype, np.shape(operand), axis, - sharding=operand.sharding if config.sharding_in_types.value else None) + sharding=operand.aval.sharding if config.sharding_in_types.value else None) res = reduce([operand, indices], [get_identity(operand.dtype), np.array(0, index_dtype)], _ArgMinMaxReducer(value_comparator), @@ -6644,7 +6645,7 @@ def _const(example, val): def _zero(x): if config.sharding_in_types.value: return full_like(x, shape=(), fill_value=0, - sharding=x.sharding.with_spec(P())) # type: ignore + sharding=x.aval.sharding.with_spec(P())) # type: ignore return full_like(x, shape=(), fill_value=0) _ones: Callable = partial(full_like, fill_value=1) @@ -6652,7 +6653,7 @@ def _zero(x): def _one(x): if config.sharding_in_types.value: return full_like(x, shape=(), fill_value=1, - sharding=x.sharding.with_spec(P())) + sharding=x.aval.sharding.with_spec(P())) return full_like(x, shape=(), fill_value=1) _twos: Callable = partial(full_like, fill_value=2) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index dd08ca1e91a3..72ac74c38c5d 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -667,7 +667,7 @@ def _one_hot(x: Array, num_classes: int, *, rhs_shape.insert(output_pos_axis, num_classes) if config.sharding_in_types.value: # TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too? - rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error + rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error else: rhs_sharding = None rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3c870c259287..dc689b619a6c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5553,7 +5553,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, weak_type = dtype is None and dtypes.is_weakly_typed(object) if (config.sharding_in_types.value and device is None and isinstance(object, Array)): - sharding = object.sharding + sharding = object.aval.sharding else: sharding = canonicalize_device_to_sharding(device) # type: ignore diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 636d8a68f142..c2c92e9785f1 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2838,7 +2838,7 @@ def hidden_axes(fun, *, axes: str | tuple[str, ...] 
| None = None, def decorator(*args, **kwargs): with mesh_lib.set_abstract_mesh(new_mesh): in_specs = tree_map(lambda a: core.modify_spec_for_hidden( - a.sharding.spec, new_mesh), args) + a.aval.sharding.spec, new_mesh), args) args = mesh_cast(args, in_specs) out = fun(*args, **kwargs) return mesh_cast(out, out_shardings) @@ -2859,7 +2859,7 @@ def decorator(*args, **kwargs): args = mesh_cast(args, in_shardings) out = fun(*args, **kwargs) out_specs = tree_map(lambda o: core.modify_spec_for_hidden( - o.sharding.spec, mesh_lib.get_abstract_mesh()), out) + o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out) return mesh_cast(out, out_specs) return decorator diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5196ac746be5..78bd4e0ec74f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4788,15 +4788,21 @@ def test_basic_mul(self, mesh): s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) - @jax.jit def f(x): - self.assertEqual(x.sharding.spec, s.spec) + self.assertEqual(x.aval.sharding.spec, s.spec) x = x * 2 - self.assertEqual(x.sharding.spec, s.spec) + self.assertEqual(x.aval.sharding.spec, s.spec) x = x * x - self.assertEqual(x.sharding.spec, s.spec) + self.assertEqual(x.aval.sharding.spec, s.spec) return x + # Eager mode + out = f(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2)) + + f = jax.jit(f) + out = f(arr) self.assertEqual(out.sharding, s) self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2)) @@ -4832,9 +4838,9 @@ def test_fully_replicated_array_mul(self, mesh): @jax.jit def f(x, y): - self.assertEqual(x.sharding.spec, s.spec) + self.assertEqual(x.aval.sharding.spec, s.spec) out = x * y - self.assertEqual(out.sharding.spec, s.spec) + self.assertEqual(out.aval.sharding.spec, s.spec) return out out = f(arr1, arr2) @@ -4876,16 +4882,21 @@ def test_dot_general(self, spec1, spec2, out_spec, collective_name, mesh): arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1)) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2)) - @jax.jit def f(x, y): out = x @ y - self.assertEqual(out.sharding.spec, out_spec) + self.assertEqual(out.aval.sharding.spec, out_spec) return out out = f(arr1, arr2) self.assertArraysEqual(out, np_inp1 @ np_inp1.T) self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) + f = jax.jit(f) + + out = f(arr1, arr2) + self.assertArraysEqual(out, np_inp1 @ np_inp1.T) + self.assertEqual(out.sharding, NamedSharding(mesh, out_spec)) + lowered = f.lower(arr1, arr2) self.check_wsc_in_lowered(lowered.as_text()) @@ -4912,16 +4923,21 @@ def test_dot_general_out_sharding(self, mesh): arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x'))) - @jax.jit def f(x, y): out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None)) - self.assertEqual(out.sharding.spec, P('x', None)) + self.assertEqual(out.aval.sharding.spec, P('x', None)) return jnp.sum(out) out = f(arr1, arr2) self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T)) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + f = jax.jit(f) + + out = f(arr1, arr2) + self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T)) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + out = jax.grad(f, argnums=(0, 1))(arr1, arr2) self.assertEqual(out[0].sharding, arr1.sharding) self.assertEqual(out[1].sharding, arr2.sharding) @@ -4999,6 +5015,16 @@ def test_aval_repr(self, mesh): aval = aval.update(sharding=NamedSharding(mesh, P(('model', 'data'), 
None))) self.assertEqual(aval.str_short(), 'float32[128@(model,data),64]') + @jtu.with_user_mesh((2, 1), ('x', 'y')) + def test_jnp_ones_mesh_context_eager(self, mesh): + s = NamedSharding(mesh, P('x', None)) + out = jnp.ones((8, 2), dtype=jnp.int32, device=s) + self.assertEqual(out.sharding, s) + + s = NamedSharding(mesh, P('x', 'y')) + out = jnp.ones((8, 2), dtype=jnp.int32, device=s) + self.assertEqual(out.sharding, s) + @parameterized.named_parameters( ('all', None, P('x', 'y'), P(), True), ('first', 0, P('x', 'y'), P('y'), True), @@ -5014,9 +5040,9 @@ def test_reduce_sum(self, axis, in_spec, out_spec, reduce, mesh): @jax.jit def f(x): - self.assertEqual(x.sharding.spec, s.spec) + self.assertEqual(x.aval.sharding.spec, s.spec) y = jnp.sum(x, axis=axis) - self.assertEqual(y.sharding.spec, out_spec) + self.assertEqual(y.aval.sharding.spec, out_spec) return y out = f(arr) @@ -5045,9 +5071,9 @@ def test_reduce_max(self, axis, in_spec, out_spec, reduce, mesh): @jax.jit def f(x): - self.assertEqual(x.sharding.spec, s.spec) + self.assertEqual(x.aval.sharding.spec, s.spec) y = jnp.max(x, axis=axis) - self.assertEqual(y.sharding.spec, out_spec) + self.assertEqual(y.aval.sharding.spec, out_spec) return y out = f(arr) @@ -5090,7 +5116,7 @@ def test_broadcast_in_dim(self, axis, out_spec, mesh): @jax.jit def f(x): y = jnp.expand_dims(x, axis=axis) - self.assertEqual(y.sharding.spec, out_spec) + self.assertEqual(y.aval.sharding.spec, out_spec) return y out = f(arr) @@ -5113,7 +5139,7 @@ def test_integer_pow(self, pow, mesh): @jax.jit def f(x): y = x ** pow - self.assertEqual(y.sharding.spec, s.spec) + self.assertEqual(y.aval.sharding.spec, s.spec) return y out = f(arr) @@ -5136,7 +5162,7 @@ def f(x, y): return x + y with self.assertRaisesRegex( - ValueError, "For primitive add, context mesh.*aval mesh"): + ValueError, "For primitive.*context mesh.*aval mesh"): f(arr1, arr2) @jtu.with_user_mesh((2, 2), ('x', 'y')) @@ -5148,7 +5174,7 @@ def test_sin_unop(self, mesh): @jax.jit def f(x): y = lax.sin(x) - self.assertEqual(y.sharding.spec, s.spec) + self.assertEqual(y.aval.sharding.spec, s.spec) return y out = f(arr) @@ -5168,7 +5194,7 @@ def f(x): assert x.dtype == jnp.int32 y = jnp.array(x, dtype=jnp.float32) self.assertEqual(y.dtype, jnp.float32) - self.assertEqual(y.sharding.spec, s.spec) + self.assertEqual(y.aval.sharding.spec, s.spec) return y f(arr) @@ -5182,7 +5208,7 @@ def test_lax_transpose_rule(self, mesh): @jax.jit def f(x): y = jnp.transpose(x, (1, 2, 0)) - self.assertEqual(y.sharding.spec, P('y', 'z', 'x')) + self.assertEqual(y.aval.sharding.spec, P('y', 'z', 'x')) return y out = f(arr) @@ -5201,7 +5227,7 @@ def test_broadcasted_iota_with_sharding(self, mesh): @jax.jit def f(x): y = jax.nn.one_hot(x, 4) - self.assertEqual(y.sharding.spec, P('x', None)) + self.assertEqual(y.aval.sharding.spec, P('x', None)) return y out = f(arr) @@ -5211,7 +5237,7 @@ def f(x): def g(x): x = x * 2 y = jax.lax.broadcasted_iota(x.dtype, (8, 2), 0, sharding=P('x', 'y')) - self.assertEqual(y.sharding.spec, P('x', 'y')) + self.assertEqual(y.aval.sharding.spec, P('x', 'y')) return x, y _, out = g(arr) @@ -5226,8 +5252,8 @@ def test_einsum_with_out_sharding(self, mesh): @jax.jit def f(x, y): out = jnp.einsum('xy,yz->xz', x, y, - out_sharding=NamedSharding(x.sharding.mesh, P('x', None))) - self.assertEqual(out.sharding.spec, P('x', None)) + out_sharding=NamedSharding(x.aval.sharding.mesh, P('x', None))) + self.assertEqual(out.aval.sharding.spec, P('x', None)) return out out = f(arr1, arr2) @@ -5240,7 +5266,7 @@ def 
f(x, y): @jax.jit def g(x, y): out = jnp.einsum('xy,yz->xz', x, y, out_sharding=P('x', None)) - self.assertEqual(out.sharding.spec, P('x', None)) + self.assertEqual(out.aval.sharding.spec, P('x', None)) return out arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) @@ -5270,7 +5296,7 @@ def test_einsum_inverse(self, mesh): def h(x, y): spec = P('x', None, 'y', None) out = jnp.einsum('btd,dhq->bhtq', x, y, out_sharding=spec) - self.assertEqual(out.sharding.spec, spec) + self.assertEqual(out.aval.sharding.spec, spec) return out arr1 = jax.device_put(np_inp.reshape(8, 4, 2), @@ -5315,7 +5341,7 @@ def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec, def f(x, new_sharding): y = lax.reshape(x, dst_shape, sharding=new_sharding) y = y * 2 - self.assertEqual(y.sharding.spec, dst_spec) + self.assertEqual(y.aval.sharding.spec, dst_spec) return y new_s = dst_spec if use_sharding_arg else None @@ -5384,7 +5410,7 @@ def test_reshape_split_merge_one_axis(self, src_shape, dst_shape, src_spec, def f(x): y = lax.reshape(x, dst_shape) y = y * 2 - self.assertEqual(y.sharding.spec, dst_spec) + self.assertEqual(y.aval.sharding.spec, dst_spec) return y if error_msg: @@ -5415,7 +5441,7 @@ def test_select(self, mesh): @jax.jit def f(pred, on_true, on_false): y = lax.select(pred, on_true, on_false) - self.assertEqual(y.sharding.spec, s.spec) + self.assertEqual(y.aval.sharding.spec, s.spec) return y out = f(arr1 == arr2, arr1, arr2) @@ -5438,7 +5464,7 @@ def test_mesh_cast_reshard_error(self, mesh): @jax.jit def f(x): - y = mesh_cast(x, NamedSharding(x.sharding.mesh, P('x', None))) + y = mesh_cast(x, NamedSharding(x.aval.sharding.mesh, P('x', None))) return y with self.assertRaisesRegex( @@ -5481,18 +5507,19 @@ def test_shard_map_full_manual(self, mesh): arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) def g(x, y): - self.assertTrue(x.sharding.mesh._are_all_axes_collective) - self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(x.aval.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.aval.sharding.mesh._are_all_axes_collective) self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective) return x * y @jax.jit def f(x, y): - z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + z = shard_map(g, mesh=mesh, + in_specs=(x.aval.sharding.spec, y.aval.sharding.spec), out_specs=P('x', 'y'))(x, y) - self.assertEqual(z.sharding.spec, P('x', 'y')) + self.assertEqual(z.aval.sharding.spec, P('x', 'y')) out = z * 2 - self.assertEqual(out.sharding.spec, P('x', 'y')) + self.assertEqual(out.aval.sharding.spec, P('x', 'y')) return out out = f(arr, arr2) @@ -5506,8 +5533,8 @@ def test_shard_map_dot(self, mesh): arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x'))) def g(x, y): - self.assertTrue(x.sharding.mesh._are_all_axes_collective) - self.assertTrue(y.sharding.mesh._are_all_axes_collective) + self.assertTrue(x.aval.sharding.mesh._are_all_axes_collective) + self.assertTrue(y.aval.sharding.mesh._are_all_axes_collective) self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective) allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True) z = x @ allgatherd_y @@ -5515,11 +5542,12 @@ def g(x, y): @jax.jit def f(x, y): - z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + z = shard_map(g, mesh=mesh, + in_specs=(x.aval.sharding.spec, y.aval.sharding.spec), out_specs=P('x', None))(x, y) - self.assertEqual(z.sharding.spec, P('x', None)) + self.assertEqual(z.aval.sharding.spec, 
P('x', None)) out = z * 2 - self.assertEqual(out.sharding.spec, P('x', None)) + self.assertEqual(out.aval.sharding.spec, P('x', None)) return out out = f(arr, arr2) @@ -5534,7 +5562,7 @@ def test_slice(self, mesh): @jax.jit def f(x): y = lax.slice(x, (0, 0), (4, 3)) - self.assertEqual(y.sharding.spec, P('x', None)) + self.assertEqual(y.aval.sharding.spec, P('x', None)) return y out = f(arr) @@ -5565,7 +5593,7 @@ def test_squeeze(self, mesh): @jax.jit def f(x): y = lax.squeeze(x, (2,)) - self.assertEqual(y.sharding.spec, P('x', None)) + self.assertEqual(y.aval.sharding.spec, P('x', None)) return y out = f(arr) @@ -5591,7 +5619,7 @@ def test_pad(self, mesh): @partial(jax.jit, static_argnums=(1, 2)) def f(x, padding_config, spec): y = lax.pad(x, 0., padding_config) - self.assertEqual(y.sharding.spec, spec) + self.assertEqual(y.aval.sharding.spec, spec) return y out = f(arr, ((2, 2, 0),), P('x')) @@ -5639,7 +5667,7 @@ def f(x, y, method='jnp'): else: assert method == 'lax' y = lax.concatenate([x, y], dimension=1) - self.assertEqual(y.sharding.spec, P('x', 'y')) + self.assertEqual(y.aval.sharding.spec, P('x', 'y')) return y out = f(arr1, arr2) @@ -5677,14 +5705,14 @@ def test_scan(self, mesh): @jax.jit def f(carry, xs): def g(carry, x): - self.assertEqual(carry.sharding.spec, P(None, 'x')) - self.assertEqual(x.sharding.spec, P('x', 'y')) + self.assertEqual(carry.aval.sharding.spec, P(None, 'x')) + self.assertEqual(x.aval.sharding.spec, P('x', 'y')) y = carry @ x - self.assertEqual(y.sharding.spec, P(None, 'y')) + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jax.nn.relu(y) - self.assertEqual(z.sharding.spec, P(None, 'y')) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) a = z @ x.T - self.assertEqual(a.sharding.spec, P(None, 'x')) + self.assertEqual(a.aval.sharding.spec, P(None, 'x')) return a, y return jax.lax.scan(g, carry, xs) @@ -5714,9 +5742,9 @@ def test_argminmax(self, mesh): @jax.jit def f(x): z = jnp.argmax(x, axis=0) - self.assertEqual(z.sharding.spec, P('y')) + self.assertEqual(z.aval.sharding.spec, P('y')) a = jnp.argmin(x, axis=1) - self.assertEqual(a.sharding.spec, P('x')) + self.assertEqual(a.aval.sharding.spec, P('x')) return z, a out1, out2 = f(arr) @@ -5734,11 +5762,11 @@ def test_only_auto(self, mesh): @jax.jit def f(x, x2): y = x * 2 - self.assertEqual(y.sharding.spec, P(None, None)) + self.assertEqual(y.aval.sharding.spec, P(None, None)) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(None, None)) + self.assertEqual(z.aval.sharding.spec, P(None, None)) a = z @ x2 - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) return a out = f(arr, arr.T) @@ -5819,13 +5847,13 @@ def f(x): y = x * 2 with use_hidden_axes('x', 'y'): y = mesh_cast(y, P(None, None)) - self.assertEqual(y.sharding.spec, P(None, None)) + self.assertEqual(y.aval.sharding.spec, P(None, None)) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(None, None)) + self.assertEqual(z.aval.sharding.spec, P(None, None)) a = z @ z.T - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) a = mesh_cast(a, P('x', None)) - self.assertEqual(a.sharding.spec, P('x', None)) + self.assertEqual(a.aval.sharding.spec, P('x', None)) return a out = f(arr) @@ -5847,13 +5875,13 @@ def f(x): y = x * 2 with use_visible_axes('x', 'y'): y = mesh_cast(y, P(None, 'y')) - self.assertEqual(y.sharding.spec, P(None, 'y')) + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) - 
self.assertEqual(z.sharding.spec, P(None, 'y')) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) a = z @ z.T - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) a = mesh_cast(a, P(None, None)) - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) return a out = f(arr) @@ -5873,13 +5901,13 @@ def f(x): y = x * 2 with use_hidden_axes('x'): y = mesh_cast(y, P(None, 'y')) - self.assertEqual(y.sharding.spec, P(None, 'y')) + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(None, 'y')) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) a = z @ z.T - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) a = mesh_cast(a, P('x', None)) - self.assertEqual(a.sharding.spec, P('x', None)) + self.assertEqual(a.aval.sharding.spec, P('x', None)) return a out = f(arr) @@ -5915,8 +5943,8 @@ def test_split(self, mesh): @partial(jax.jit, static_argnums=(1, 2)) def f(x, sizes=(4, 4), axis=0): ys = lax.split(x, sizes, axis=axis) - self.assertEqual(ys[0].sharding.spec, P('x', 'y')) - self.assertEqual(ys[1].sharding.spec, P('x', 'y')) + self.assertEqual(ys[0].aval.sharding.spec, P('x', 'y')) + self.assertEqual(ys[1].aval.sharding.spec, P('x', 'y')) return ys f(arr) @@ -6010,7 +6038,7 @@ def test_out_sharding_mix_axis_types(self, mesh): @jax.jit def f(x): y = x * 2 - self.assertEqual(y.sharding.spec, P('x', None, None)) + self.assertEqual(y.aval.sharding.spec, P('x', None, None)) return y out = f(arr) @@ -6032,18 +6060,18 @@ def test_auto_mode_mix(self, mesh): @partial(hidden_axes, axes='x', out_shardings=P('x', None)) def h(y): - self.assertEqual(y.sharding.spec, P(None, 'y')) + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(None, 'y')) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) a = z @ z.T - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) return a @jax.jit def g(x): y = x * 2 a = h(y) - self.assertEqual(a.sharding.spec, P('x', None)) + self.assertEqual(a.aval.sharding.spec, P('x', None)) return a out = g(arr) @@ -6063,18 +6091,18 @@ def test_full_user_mode(self, mesh): # No axes specified means full visible mode. 
@partial(visible_axes, in_shardings=P('x', 'y')) def h(y): - self.assertEqual(y.sharding.spec, P('x', 'y')) + self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P('x', 'y')) + self.assertEqual(z.aval.sharding.spec, P('x', 'y')) a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', None)) - self.assertEqual(a.sharding.spec, P('x', None)) + self.assertEqual(a.aval.sharding.spec, P('x', None)) return a @jax.jit def f(x): y = x * 2 a = h(y) - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) return a out = f(arr) @@ -6093,18 +6121,18 @@ def test_mix_to_full_user_mode(self, mesh): @partial(visible_axes, axes='y', in_shardings=P('x', 'y')) def h(y): - self.assertEqual(y.sharding.spec, P('x', 'y')) + self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P('x', 'y')) + self.assertEqual(z.aval.sharding.spec, P('x', 'y')) a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P('x', 'y')) - self.assertEqual(a.sharding.spec, P('x', 'y')) + self.assertEqual(a.aval.sharding.spec, P('x', 'y')) return a @jax.jit def f(x): y = x * 2 a = h(y) - self.assertEqual(a.sharding.spec, P('x', None)) + self.assertEqual(a.aval.sharding.spec, P('x', None)) return a out = f(arr) @@ -6119,18 +6147,18 @@ def test_full_auto_to_partial_user(self, mesh): @partial(visible_axes, axes='y', in_shardings=P(None, 'y')) def h(y): - self.assertEqual(y.sharding.spec, P(None, 'y')) + self.assertEqual(y.aval.sharding.spec, P(None, 'y')) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(None, 'y')) + self.assertEqual(z.aval.sharding.spec, P(None, 'y')) a = jnp.einsum('ab,bc->ac', z, z.T, out_sharding=P(None, 'y')) - self.assertEqual(a.sharding.spec, P(None, 'y')) + self.assertEqual(a.aval.sharding.spec, P(None, 'y')) return a @jax.jit def f(x): y = x * 2 a = h(y) - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) return a out = f(arr) @@ -6147,7 +6175,7 @@ def test_auto_gather_out_sharding(self, mesh): def f(embed_vd, token_bt): out = embed_vd.at[token_bt].get(out_sharding=P('x', None, None)) self.assertEqual(out.shape, (8, 4, 16)) - self.assertEqual(out.sharding.spec, P('x', None, None)) + self.assertEqual(out.aval.sharding.spec, P('x', None, None)) return out out = f(embed, tok) @@ -6213,11 +6241,11 @@ def test_full_auto_outside_jit(self, mesh): @jax.jit def f(x): y = x * 2 - self.assertEqual(y.sharding.spec, P(None, None)) + self.assertEqual(y.aval.sharding.spec, P(None, None)) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P(None, None)) + self.assertEqual(z.aval.sharding.spec, P(None, None)) a = z @ z.T - self.assertEqual(a.sharding.spec, P(None, None)) + self.assertEqual(a.aval.sharding.spec, P(None, None)) return a hf = hidden_axes(f, axes=('x', 'y'), out_shardings=P('x', 'y')) @@ -6234,9 +6262,9 @@ def test_full_visible_outside_jit(self, mesh): @jax.jit def f(x): y = x * 2 - self.assertEqual(y.sharding.spec, P('x', 'y')) + self.assertEqual(y.aval.sharding.spec, P('x', 'y')) z = jnp.sin(y) - self.assertEqual(z.sharding.spec, P('x', 'y')) + self.assertEqual(z.aval.sharding.spec, P('x', 'y')) return z hf = visible_axes(f, axes=('x', 'y'), in_shardings=P('x', 'y'))
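Illustrative note (not part of the patch): besides the .sharding to .aval.sharding renames, the test changes add eager-mode coverage: several bodies now run once without jax.jit before being jitted, and jnp.ones honors a NamedSharding passed via device= under the user mesh. A condensed sketch of that coverage, assuming the sharding_in_types flag and the (2, 1) 'x'/'y' mesh configured by jtu.with_user_mesh:

# Sketch only: condensed version of the new eager-mode checks.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

def eager_checks(mesh):
  s = NamedSharding(mesh, P('x', None))

  # Creation functions honor device= eagerly (test_jnp_ones_mesh_context_eager).
  ones = jnp.ones((8, 2), dtype=jnp.int32, device=s)
  assert ones.sharding == s

  # Tests such as test_basic_mul now run the body eagerly, then under jit,
  # and expect the same output sharding from both.
  f = lambda x: (x * 2) * (x * 2)
  arr = jax.device_put(np.arange(16.0).reshape(8, 2), s)
  assert f(arr).sharding == s
  assert jax.jit(f)(arr).sharding == s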