Don't allow users to query tracer.sharding even under sharding in types mode.

Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
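
For illustration, a minimal sketch of the behavior this change enforces (assuming the sharding-in-types mode described here is enabled and two devices are available): inside `jit` the argument is a tracer, so `.sharding` now raises, while `.aval.sharding` reads the same way in traced and eager code.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))                     # assumes 2 devices are available
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P('x')))

@jax.jit
def f(a):
    spec = a.aval.sharding.spec   # supported spelling, identical under jit and in eager mode
    # a.sharding                  # would now raise AttributeError on the tracer
    return a + 1

f(x)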
yashk2810 authored and Google-ML-Automation committed Jan 20, 2025
1 parent 7f19b34 commit d50d1e2
Showing 9 changed files with 167 additions and 147 deletions.
9 changes: 2 additions & 7 deletions jax/_src/api.py
@@ -67,8 +67,7 @@
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
NamedSharding)
from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
from jax._src.layout import Layout, AutoLayout
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
@@ -2564,11 +2563,7 @@ def _sds_aval_mapping(x):
aval = ShapedArray(
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
if config.sharding_in_types.value and isinstance(x.sharding, NamedSharding):
return aval.update(sharding=NamedSharding(
x.sharding.mesh.abstract_mesh,
x.sharding.spec._normalized_spec(x.ndim)))
return aval
return core.update_aval_with_sharding(aval, x.sharding)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping


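For context, a hedged sketch of where `_sds_aval_mapping` matters: abstracting a `jax.ShapeDtypeStruct` that carries a `NamedSharding` now goes through `core.update_aval_with_sharding`, so the resulting aval records that sharding (same assumptions as the sketch above).

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))                  # assumes 2 devices
sds = jax.ShapeDtypeStruct((8, 2), jnp.float32,
                           sharding=NamedSharding(mesh, P('x', None)))
out = jax.eval_shape(lambda a: a + 1, sds)          # abstraction goes through _sds_aval_mapping
# During tracing, the function body can read a.aval.sharding.spec, not a.sharding.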
17 changes: 7 additions & 10 deletions jax/_src/array.py
@@ -41,7 +41,7 @@
from jax._src.lib import xla_extension as xe
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding,
PmapSharding, SingleDeviceSharding,
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
@@ -753,7 +753,8 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray:
first_value = per_device_values[0]
expected_dtype = first_value.dtype
expected_shape = sharding.shard_shape(shape)
aval = core.ShapedArray(shape, expected_dtype)
aval = core.update_aval_with_sharding(
core.ShapedArray(shape, expected_dtype), sharding)
_validate_shape_and_dtype_for_per_device_arrays(
per_device_values,
expected_shape=expected_shape,
@@ -1017,7 +1018,8 @@ def make_array_from_single_device_arrays(
raise ValueError(
"jax.make_array_from_single_device_arrays requires a list of concrete"
f" arrays as input. got types {set(map(type, arrays))}")
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
aval = core.update_aval_with_sharding(
core.ShapedArray(shape, arrays[0].dtype, weak_type=False), sharding)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
committed=True)
@@ -1028,13 +1030,7 @@ def make_array_from_single_device_arrays(
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity

def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
return self.aval.update(sharding=NamedSharding(
self.sharding.mesh.abstract_mesh,
self.sharding.spec._normalized_spec(self.ndim)))
else:
return self.aval

return core.update_aval_with_sharding(self.aval, self.sharding)
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array

# TODO(jakevdp) replace this with true inheritance at the C++ level.
Expand Down Expand Up @@ -1179,6 +1175,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):


def _array_global_result_handler(global_aval, out_sharding, committed):
global_aval = core.update_aval_with_sharding(global_aval, out_sharding)
if global_aval.dtype == dtypes.float0:
return lambda _: np.zeros(global_aval.shape, dtypes.float0)
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
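A hedged note on the concrete-array side: a real `jax.Array` keeps its `.sharding` attribute; what changes is that `_get_aval_array` (and the result handler above) also stamp the sharding onto the aval, so `.aval.sharding` is the spelling that works for both concrete arrays and tracers. Sketch under the same assumptions as above:

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))                  # assumes 2 devices
x = jax.device_put(jnp.zeros((8, 2)), NamedSharding(mesh, P('x', None)))

x.sharding        # still fine: x is a concrete jax.Array, not a tracer
x.aval.sharding   # now also populated (via _get_aval_array), with the matching spec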
51 changes: 25 additions & 26 deletions jax/_src/core.py
@@ -816,6 +816,12 @@ def __getattr__(self, name):
# if the aval property raises an AttributeError, gets caught here
assert not config.enable_checks.value or name != "aval"

if name == 'sharding':
raise AttributeError(
self,
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")

try:
attr = getattr(self.aval, name)
except AttributeError as err:
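
A self-contained plain-Python sketch of the guard added above (a hypothetical FakeTracer, not the real Tracer class): attribute lookups are still forwarded to the aval, but 'sharding' is intercepted first so it can never be answered on the tracer itself.

class FakeTracer:
    """Plain-Python sketch of the guard above; not the real Tracer class."""

    def __init__(self, aval):
        self.aval = aval

    def __getattr__(self, name):
        if name == 'sharding':
            # Mirror the new behavior: refuse to answer on the tracer itself.
            raise AttributeError(
                "The 'sharding' attribute is not available on this tracer; "
                "query `.aval.sharding` instead.")
        # Everything else is forwarded to the aval, as before.
        return getattr(self.aval, name)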
@@ -1421,6 +1427,13 @@ def check_valid_jaxtype(x):
raise TypeError(
f"Value {x!r} of type {type(x)} is not a valid JAX type")

def update_aval_with_sharding(aval, sharding):
from jax._src.sharding_impls import NamedSharding # type: ignore
if config.sharding_in_types.value and isinstance(sharding, NamedSharding):
aval = aval.update(sharding=NamedSharding(
sharding.mesh.abstract_mesh, sharding.spec._normalized_spec(aval.ndim)))
return aval


# We have three flavors of abstractification APIs here which each used to have
# their own separate implementation. Now they're effectively the same, with the
@@ -1433,8 +1446,6 @@ def check_valid_jaxtype(x):
# TODO(jakevdp): can these be unified further?

def shaped_abstractify(x):
from jax._src.sharding_impls import NamedSharding # type: ignore

typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
@@ -1448,12 +1459,7 @@ def shaped_abstractify(x):
if hasattr(x, 'dtype'):
aval = ShapedArray(np.shape(x), x.dtype,
weak_type=getattr(x, 'weak_type', False))
if (config.sharding_in_types.value and hasattr(x, 'sharding') and
isinstance(x.sharding, NamedSharding)):
return aval.update(sharding=NamedSharding(
x.sharding.mesh.abstract_mesh,
x.sharding.spec._normalized_spec(aval.ndim)))
return aval
return update_aval_with_sharding(aval, getattr(x, 'sharding', None))
raise TypeError(
f"Cannot interpret value of type {typ} as an abstract array; it "
"does not have a dtype attribute")
@@ -1701,13 +1707,17 @@ def get_sharding(sharding, ndim):
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
return _maybe_modify_sharding(sharding)

context_mesh = mesh_lib.get_abstract_mesh()
if not context_mesh:
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
return NamedSharding(context_mesh, P(*[None] * ndim))
out_s = _maybe_modify_sharding(sharding)
else:
context_mesh = mesh_lib.get_abstract_mesh()
if not context_mesh:
raise RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
out_s = NamedSharding(context_mesh, P(*[None] * ndim))
if not isinstance(out_s.mesh, mesh_lib.AbstractMesh):
raise ValueError("Mesh of an aval must be an AbstractMesh. "
f"Got {out_s.mesh} of type {type(out_s.mesh)}")
return out_s


class ShapedArray(UnshapedArray):
@@ -1720,9 +1730,6 @@ def __init__(self, shape, dtype, weak_type=False, sharding=None):
self.weak_type = weak_type
if config.sharding_in_types.value:
self.sharding = get_sharding(sharding, len(self.shape))
if not isinstance(self.sharding.mesh, mesh_lib.AbstractMesh):
raise ValueError(
f"Mesh of an aval must be an AbstractMesh. Got {self.sharding.mesh}")

def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
if shape is None:
@@ -1796,14 +1803,6 @@ def _get_shape_sharding_str(shape, spec):
out.append(f"{s1}@{s2}")
return ','.join(out)

def _get_abstract_sharding(val):
from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error

if (config.sharding_in_types.value and hasattr(val, 'sharding') and
isinstance(val.sharding, NamedSharding)):
return NamedSharding(val.sharding.mesh.abstract_mesh,
val.sharding.spec._normalized_spec(val.ndim))
return None

def primal_dtype_to_tangent_dtype(primal_dtype):
if isinstance(primal_dtype, dtypes.ExtendedDType):
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/loops.py
@@ -230,9 +230,9 @@ def scan(f, init, xs, length=None):
if not hasattr(x, 'shape')))) from err

if (config.sharding_in_types.value and
not all(x.sharding.spec[0] is None for x in xs_flat)):
not all(x.aval.sharding.spec[0] is None for x in xs_flat)):
raise ValueError('0th dimension of all xs should be replicated. Got '
f'{", ".join(str(x.sharding.spec) for x in xs_flat)}')
f'{", ".join(str(x.aval.sharding.spec) for x in xs_flat)}')

if length is not None:
try:
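A hedged sketch of the constraint being checked, now read from the aval: every `xs` leaf must be replicated along the scanned (0th) dimension; a leaf sharded on that axis is expected to trigger the `ValueError` above (same assumptions as the earlier sketches).

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))                                             # assumes 2 devices
xs_ok  = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P(None, 'x')))   # dim 0 replicated
xs_bad = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P('x', None)))   # dim 0 sharded

def step(carry, x):
    return carry + x.sum(), None

jax.lax.scan(step, 0.0, xs_ok)      # accepted
# jax.lax.scan(step, 0.0, xs_bad)   # expected to raise the ValueError above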
15 changes: 8 additions & 7 deletions jax/_src/lax/lax.py
@@ -586,7 +586,7 @@ def _convert_element_type(

if (config.sharding_in_types.value and sharding is None and
isinstance(operand, Array)):
sharding = operand.sharding
sharding = operand.aval.sharding

sharding = canonicalize_sharding(sharding, check_mesh_consistency=False) # type: ignore

@@ -1920,7 +1920,8 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
fill_value = _convert_element_type(fill_value, dtype, weak_type)
if (sharding is not None and not isinstance(sharding, PmapSharding) and
isinstance(fill_value, array.ArrayImpl)):
isinstance(fill_value, array.ArrayImpl) and
not config.sharding_in_types.value):
broadcast_shape = sharding.shard_shape(shape)
shard = broadcast(fill_value, broadcast_shape)
return array.make_array_from_callback(shape, sharding, lambda _: shard)
@@ -2137,7 +2138,7 @@ def full_like(x: ArrayLike | DuckTypedArray,

if (config.sharding_in_types.value and sharding is None and
isinstance(x, Array)):
sharding = x.sharding
sharding = x.aval.sharding
else:
# If `x` has a sharding but no `_committed` attribute
# (in case of ShapeDtypeStruct), default it to True.
@@ -4496,7 +4497,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
broadcast_dimensions=broadcast_dimensions)
if config.sharding_in_types.value:
if sharding is not None:
assert sharding == aval_out.sharding
assert sharding == aval_out.sharding, (sharding, aval_out.sharding)
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]

@@ -5656,7 +5657,7 @@ def _compute_argminmax(value_comparator, get_identity,
axis, = axes
indices = broadcasted_iota(
index_dtype, np.shape(operand), axis,
sharding=operand.sharding if config.sharding_in_types.value else None)
sharding=operand.aval.sharding if config.sharding_in_types.value else None)
res = reduce([operand, indices],
[get_identity(operand.dtype), np.array(0, index_dtype)],
_ArgMinMaxReducer(value_comparator),
@@ -6644,15 +6645,15 @@ def _const(example, val):
def _zero(x):
if config.sharding_in_types.value:
return full_like(x, shape=(), fill_value=0,
sharding=x.sharding.with_spec(P())) # type: ignore
sharding=x.aval.sharding.with_spec(P())) # type: ignore
return full_like(x, shape=(), fill_value=0)

_ones: Callable = partial(full_like, fill_value=1)

def _one(x):
if config.sharding_in_types.value:
return full_like(x, shape=(), fill_value=1,
sharding=x.sharding.with_spec(P()))
sharding=x.aval.sharding.with_spec(P()))
return full_like(x, shape=(), fill_value=1)

_twos: Callable = partial(full_like, fill_value=2)
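The lax.py changes above all follow the same pattern: read the sharding from the operand's aval rather than from the operand. For illustration, a hedged rough equivalent of what `_zero(x)` now does, written from eager user code (where `x.sharding` is still available): build a scalar constant that is replicated (`P()`) on the same mesh as `x`, using the public `full_like` with its `sharding` argument; the `.with_spec` call mirrors the diff. Assumptions as in the earlier sketches.

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))                    # assumes 2 devices
x = jax.device_put(jnp.ones((8,)), NamedSharding(mesh, P('x')))

# Hypothetical eager equivalent of _zero(x): a scalar, replicated on x's mesh.
zero = lax.full_like(x, fill_value=0, shape=(),
                     sharding=x.sharding.with_spec(P()))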
2 changes: 1 addition & 1 deletion jax/_src/nn/functions.py
@@ -667,7 +667,7 @@ def _one_hot(x: Array, num_classes: int, *,
rhs_shape.insert(output_pos_axis, num_classes)
if config.sharding_in_types.value:
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
else:
rhs_sharding = None
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
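A short hedged sketch of the `one_hot` call site: the replicated iota built internally now takes its mesh from the input's aval, so the same call works whether the input is a concrete array or a tracer (assumptions as above).

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))                    # assumes 2 devices
labels = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))

@jax.jit
def encode(y):
    # The internal rhs_sharding is built from y.aval.sharding.mesh, not y.sharding.
    return jax.nn.one_hot(y, num_classes=8)

encode(labels)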
2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
@@ -5553,7 +5553,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
weak_type = dtype is None and dtypes.is_weakly_typed(object)
if (config.sharding_in_types.value and device is None and
isinstance(object, Array)):
sharding = object.sharding
sharding = object.aval.sharding
else:
sharding = canonicalize_device_to_sharding(device) # type: ignore

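A hedged sketch of the `jnp.array` change: with no explicit device and sharding-in-types enabled, the result reuses the input's sharding, which must now be read from the aval so the same code also works when the input is a tracer.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ('x',))                    # assumes 2 devices
x = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))

@jax.jit
def f(a):
    return jnp.array(a)    # sharding taken from a.aval.sharding, which tracers do have

f(x)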
4 changes: 2 additions & 2 deletions jax/_src/pjit.py
@@ -2838,7 +2838,7 @@ def hidden_axes(fun, *, axes: str | tuple[str, ...] | None = None,
def decorator(*args, **kwargs):
with mesh_lib.set_abstract_mesh(new_mesh):
in_specs = tree_map(lambda a: core.modify_spec_for_hidden(
a.sharding.spec, new_mesh), args)
a.aval.sharding.spec, new_mesh), args)
args = mesh_cast(args, in_specs)
out = fun(*args, **kwargs)
return mesh_cast(out, out_shardings)
@@ -2859,7 +2859,7 @@ def decorator(*args, **kwargs):
args = mesh_cast(args, in_shardings)
out = fun(*args, **kwargs)
out_specs = tree_map(lambda o: core.modify_spec_for_hidden(
o.sharding.spec, mesh_lib.get_abstract_mesh()), out)
o.aval.sharding.spec, mesh_lib.get_abstract_mesh()), out)
return mesh_cast(out, out_specs)
return decorator

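A hedged generic sketch of the pattern used in both decorators above: walk a pytree of arrays or tracers with `tree_map` and derive a per-leaf spec from the aval's sharding rather than from the value itself.

import jax

def specs_of(tree):
    # Hypothetical helper: collect each leaf's PartitionSpec from its aval.
    return jax.tree_util.tree_map(lambda a: a.aval.sharding.spec, tree)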