[sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by properly registering the sharding on the aval created by SDS in its pytype_aval_mappings entry.

Also, if we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
yashk2810 authored and Google-ML-Automation committed Jan 14, 2025
1 parent a1bbad6 commit c72ed26
Showing 6 changed files with 44 additions and 10 deletions.
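Taken together, the two changes let tracing-only entry points see shardings without concrete data. A minimal end-to-end sketch of the intended usage, assuming a 4-device runtime (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU) and the experimental sharding_in_types mode being enabled; flag spelling and behavior may change while the feature is experimental:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))

@jax.jit
def f(x):
  return x * x

# A ShapeDtypeStruct is a shape/dtype placeholder; it can now also carry
# a NamedSharding that survives into the traced aval.
sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s)
print(f.lower(sds).as_text()[:400])  # lowers without concrete inputs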
10 changes: 8 additions & 2 deletions jax/_src/api.py
@@ -68,7 +68,8 @@
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import PmapSharding, TransferToMemoryKind
+from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
+                                     NamedSharding)
 from jax._src.layout import Layout, AutoLayout
 from jax._src.traceback_util import api_boundary
 from jax._src import tree_util
@@ -2562,9 +2563,14 @@ def __hash__(self):
     return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))
 
 def _sds_aval_mapping(x):
-  return ShapedArray(
+  aval = ShapedArray(
       x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
       weak_type=x.weak_type)
+  if config.sharding_in_types.value and isinstance(x.sharding, NamedSharding):
+    return aval.update(sharding=NamedSharding(
+        x.sharding.mesh.abstract_mesh,
+        x.sharding.spec._normalized_spec(x.ndim)))
+  return aval
 core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping


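The helper _normalized_spec used above is private, not public API; the hypothetical function below only illustrates the idea behind it: pad a PartitionSpec with None entries so there is exactly one entry per array dimension before storing it on the aval.

from jax.sharding import PartitionSpec as P

def normalized_spec_sketch(spec, ndim):
  # e.g. P('x') for a rank-3 array becomes P('x', None, None)
  entries = tuple(spec) + (None,) * (ndim - len(tuple(spec)))
  return P(*entries)

assert normalized_spec_sketch(P('x'), 3) == P('x', None, None)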
11 changes: 10 additions & 1 deletion jax/_src/core.py
@@ -1420,6 +1420,8 @@ def check_valid_jaxtype(x):
 # TODO(jakevdp): can these be unified further?
 
 def shaped_abstractify(x):
+  from jax._src.sharding_impls import NamedSharding  # type: ignore
+
   typ = type(x)
   if (aval_fn := pytype_aval_mappings.get(typ)):  # fast path
     return aval_fn(x)
@@ -1431,7 +1433,14 @@ def shaped_abstractify(x):
   if hasattr(x, '__jax_array__'):
     return shaped_abstractify(x.__jax_array__())
   if hasattr(x, 'dtype'):
-    return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
+    aval = ShapedArray(np.shape(x), x.dtype,
+                       weak_type=getattr(x, 'weak_type', False))
+    if (config.sharding_in_types.value and hasattr(x, 'sharding') and
+        isinstance(x.sharding, NamedSharding)):
+      return aval.update(sharding=NamedSharding(
+          x.sharding.mesh.abstract_mesh,
+          x.sharding.spec._normalized_spec(aval.ndim)))
+    return aval
   raise TypeError(
       f"Cannot interpret value of type {typ} as an abstract array; it "
       "does not have a dtype attribute")
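With this branch in place, concrete committed arrays get the same treatment as ShapeDtypeStructs: their NamedSharding, rebased onto the mesh's abstract form, rides along on the aval. A sketch under the same 4-device and experimental-flag assumptions as above:

import jax, numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
arr = jax.device_put(np.arange(16.).reshape(8, 2),
                     NamedSharding(mesh, P('x', 'y')))
# shaped_abstractify now consults arr.sharding, so the traced aval
# remembers the mesh axes and spec, not just shape and dtype.
lowered = jax.jit(lambda x: x * 2).lower(arr)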
8 changes: 7 additions & 1 deletion jax/_src/lax/lax.py
@@ -6552,7 +6552,13 @@ def _const(example, val):
   return np.array(val, dtype)
 
 _zeros: Callable = partial(full_like, fill_value=0)
-_zero: Callable = partial(full_like, shape=(), fill_value=0)
+
+def _zero(x):
+  if config.sharding_in_types.value:
+    return full_like(x, shape=(), fill_value=0,
+                     sharding=x.sharding.with_spec(P()))  # type: ignore
+  return full_like(x, shape=(), fill_value=0)
 
 _ones: Callable = partial(full_like, fill_value=1)
 
 def _one(x):
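The design choice here: a scalar zero constant cannot be partitioned, but under sharding_in_types it must still carry a sharding on the operand's mesh, or later aval checks would see a mesh mismatch. The with_spec(P()) call used above keeps the mesh and empties the spec, i.e. fully replicated. A small sketch, again assuming 4 devices:

import jax, numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
replicated = s.with_spec(P())  # same mesh, empty spec => fully replicated
print(replicated.spec)         # PartitionSpec()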
16 changes: 11 additions & 5 deletions jax/_src/lax/utils.py
@@ -22,6 +22,7 @@
 from jax._src import dispatch
 from jax._src import config
 from jax._src import dtypes
+from jax._src import mesh as mesh_lib
 from jax._src.util import safe_zip
 
 zip, unsafe_zip = safe_zip, zip
@@ -46,6 +47,13 @@ def standard_primitive(shape_rule, dtype_rule, name,
 
 def _get_array_abstraction_level(a): return a.array_abstraction_level
 
+def call_sharding_rule(rule, num_out, *avals, **kwargs):
+  if config.sharding_in_types.value:
+    if rule is None and mesh_lib.get_abstract_mesh()._are_all_axes_auto:  # type: ignore
+      return None if num_out is None else [None] * num_out
+    return rule(*avals, **kwargs)
+  return None if num_out is None else [None] * num_out
+
 def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                            sharding_rule, *avals, **kwargs):
   assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
@@ -57,8 +65,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
     out_aval = core.ShapedArray(
         shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
         weak_type=weak_type,
-        sharding=(sharding_rule(*avals, **kwargs)
-                  if config.sharding_in_types.value else None))
+        sharding=call_sharding_rule(sharding_rule, None, *avals, **kwargs))
     core.check_avals_context_mesh([out_aval], prim.name)
     return out_aval
   elif least_specialized is core.DShapedArray:
@@ -82,9 +89,8 @@ def standard_multi_result_abstract_eval(
     out_shapes = shape_rule(*avals, **kwargs)
     out_dtypes = dtype_rule(*avals, **kwargs)
     core.check_avals_context_mesh(avals, prim.name)
-    out_shardings = (sharding_rule(*avals, **kwargs)
-                     if config.sharding_in_types.value else
-                     [None] * len(out_shapes))
+    out_shardings = call_sharding_rule(
+        sharding_rule, len(out_shapes), *avals, **kwargs)
     out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
                  for s, d, weak_type, sh in zip(out_shapes, out_dtypes,
                                                 weak_types, out_shardings)]
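A standalone sketch of the new fallback, with hypothetical stand-ins (types_on, all_axes_auto) for config.sharding_in_types and the abstract-mesh query; the real code is call_sharding_rule above:

def call_sharding_rule_sketch(rule, num_out, *avals,
                              types_on=True, all_axes_auto=True, **kwargs):
  if types_on:
    if rule is None and all_axes_auto:
      # Full auto mode: a primitive without a registered sharding rule
      # is no longer an error; leave its output shardings unspecified
      # and let the partitioner decide.
      return None if num_out is None else [None] * num_out
    return rule(*avals, **kwargs)
  # Feature disabled: shardings are always unspecified.
  return None if num_out is None else [None] * num_out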
4 changes: 4 additions & 0 deletions jax/_src/mesh.py
@@ -452,6 +452,10 @@ def update_axis_types(self, new_axis_types) -> AbstractMesh:
     new_axis_types = axis_types_to_names(updated_name_to_type)
     return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
 
+  @property
+  def abstract_mesh(self):
+    return self
+
   @functools.cached_property
   def _are_all_axes_collective(self) -> bool:
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())
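The no-op property makes abstract_mesh total: callers such as _sds_aval_mapping can write sharding.mesh.abstract_mesh whether the sharding was built from a concrete Mesh or from an AbstractMesh. A small sketch, assuming at least one device:

import jax, numpy as np
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))
abstract = mesh.abstract_mesh              # concrete Mesh -> AbstractMesh
assert abstract.abstract_mesh is abstract  # AbstractMesh -> itself (new)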
5 changes: 4 additions & 1 deletion tests/pjit_test.py
@@ -4776,7 +4776,8 @@ def f(x):
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
 
-    lowered_text = f.lower(arr).as_text()
+    sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=s)
+    lowered_text = f.lower(sds).as_text()
     if config.use_shardy_partitioner.value:
       self.assertEqual(lowered_text.count('sdy.sharding_constraint'), 3)
     else:
@@ -4793,6 +4794,8 @@ def g(x):
     out = jax.jit(jax.grad(g))(arr)
     self.assertEqual(out.sharding, arr.sharding)
 
+    jax.jit(jax.grad(g)).lower(sds)  # doesn't crash
+
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_fully_replicated_array_mul(self, mesh):
     np_inp1 = np.arange(16).reshape(8, 2)
