Skip to content

Commit

Permalink
Make make_mesh take visible_axes, hidden_axes and `collective_a…
Browse files Browse the repository at this point in the history
…xes` as parameters instead of `axis_types` to make it a cleaner API.

The mesh axis names provided to those parameters must be disjoint, i.e., they must not overlap.

PiperOrigin-RevId: 717718333
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Jan 21, 2025
1 parent bba5ada commit 9e99aeb
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 9 deletions.
37 changes: 36 additions & 1 deletion jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,10 +1783,44 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
f' {s} or type: {sharding.mesh._name_to_type[s]}')
return sharding

TypeOfAxis = str | tuple[str, ...] | None

def _normalize(axes: TypeOfAxis = None) -> tuple[str, ...]:
if axes is None:
return ()
return (axes,) if isinstance(axes, str) else axes

def _get_axis_types(
    hidden_axes: TypeOfAxis = None, visible_axes: TypeOfAxis = None,
    collective_axes: TypeOfAxis = None):
  """Builds an axis-types mapping from per-type groups of axis names.

  Args:
    hidden_axes: Axis name(s) to file under `mesh_lib.AxisTypes.Hidden`.
    visible_axes: Axis name(s) to file under `mesh_lib.AxisTypes.Visible`.
    collective_axes: Axis name(s) to file under
      `mesh_lib.AxisTypes.Collective`.

  Returns:
    `None` when all three arguments are `None`. Otherwise a dict keyed by
    `mesh_lib.AxisTypes` members whose values are the normalized axis names;
    groups that normalize to an empty value are left out.

  Raises:
    ValueError: If any axis name appears in more than one group.
  """
  if all(group is None
         for group in (hidden_axes, visible_axes, collective_axes)):
    return None

  # The parameter names are rebound on purpose: the error message below uses
  # the `{name=}` f-string form and reports the normalized values.
  hidden_axes = _normalize(hidden_axes)
  visible_axes = _normalize(visible_axes)
  collective_axes = _normalize(collective_axes)

  hidden_set = set(hidden_axes)
  visible_set = set(visible_axes)
  collective_set = set(collective_axes)
  overlapping = (hidden_set & visible_set
                 or hidden_set & collective_set
                 or visible_set & collective_set)
  if overlapping:
    raise ValueError(
        f'{hidden_axes=}, {visible_axes=} and {collective_axes=} should be'
        ' non-overlapping.')

  # Insertion order is Hidden, Visible, Collective; empty groups are skipped.
  grouped = (
      (mesh_lib.AxisTypes.Hidden, hidden_axes),
      (mesh_lib.AxisTypes.Visible, visible_axes),
      (mesh_lib.AxisTypes.Collective, collective_axes),
  )
  return {axis_type: names for axis_type, names in grouped if names}


def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
*, devices: Sequence[xc.Device] | None = None,
axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
hidden_axes: TypeOfAxis = None, visible_axes: TypeOfAxis = None,
collective_axes: TypeOfAxis = None) -> mesh_lib.Mesh:
"""Creates an efficient mesh with the shape and axis names specified.
This function attempts to automatically compute a good mapping from a set of
Expand Down Expand Up @@ -1848,4 +1882,5 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
mesh_devices = mesh_utils.create_device_mesh(
new_axis_shapes, devices,
allow_split_physical_axes=allow_split_physical_axes)
axis_types = _get_axis_types(hidden_axes, visible_axes, collective_axes)
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
10 changes: 9 additions & 1 deletion jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,15 @@ def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None):
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types)
else:
return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types)
if axis_types is None:
visible_axes = hidden_axes = collective_axes = None
else:
visible_axes = axis_types.get(mesh_lib.AxisTypes.Visible, None)
hidden_axes = axis_types.get(mesh_lib.AxisTypes.Hidden, None)
collective_axes = axis_types.get(mesh_lib.AxisTypes.Collective, None)
return jax.make_mesh(mesh_shape, axis_names, visible_axes=visible_axes,
hidden_axes=hidden_axes,
collective_axes=collective_axes)

class _cached_property:
null = object()
Expand Down
43 changes: 43 additions & 0 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import dialects, ir
from jax._src.util import safe_zip
from jax._src.mesh import AxisTypes
from jax._src.sharding import common_devices_indices_map
from jax._src.sharding_impls import (
_op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map,
Expand Down Expand Up @@ -1313,6 +1314,48 @@ def test_mesh_axis_types_mismatch(self):
jax.sharding.AbstractMesh((('x', 2), ('y', 1)),
axis_types={jax.sharding.AxisTypes.Hidden: 'x'})

def test_make_mesh_axis_types(self):
  """Checks the `axis_types` that `jax.make_mesh` derives from `*_axes`."""
  # Each row: (axis_shapes, axis_names, keyword arguments, expected axis_types).
  accepted = [
      # No `*_axes` keyword at all: every axis is reported as Hidden.
      ((1, 1), ('x', 'y'), {}, {AxisTypes.Hidden: ('x', 'y')}),
      ((1, 1, 1), ('x', 'y', 'z'),
       dict(visible_axes='x', hidden_axes='y', collective_axes='z'),
       {AxisTypes.Hidden: ('y',), AxisTypes.Visible: ('x',),
        AxisTypes.Collective: ('z',)}),
      ((1, 1, 1), ('x', 'y', 'z'),
       dict(visible_axes=('x', 'y'), collective_axes='z'),
       {AxisTypes.Visible: ('x', 'y'), AxisTypes.Collective: ('z',)}),
      ((1, 1), ('x', 'y'), dict(visible_axes=('x', 'y')),
       {AxisTypes.Visible: ('x', 'y')}),
      ((1,), 'model', dict(collective_axes='model'),
       {AxisTypes.Collective: ('model',)}),
  ]
  for axis_shapes, axis_names, axes_kwargs, expected in accepted:
    mesh = jax.make_mesh(axis_shapes, axis_names, **axes_kwargs)
    self.assertDictEqual(mesh.axis_types, expected)

  # The same axis name listed under two different keywords must be rejected.
  overlap_msg = "should be non-overlapping"
  overlapping = [
      dict(hidden_axes='data', visible_axes=('data', 'seq'),
           collective_axes='model'),
      dict(hidden_axes='data', visible_axes='model', collective_axes='data'),
      dict(visible_axes=('data', 'seq'), collective_axes=('seq', 'model')),
  ]
  for axes_kwargs in overlapping:
    with self.assertRaisesRegex(ValueError, overlap_msg):
      jax.make_mesh((1, 1, 1), ('data', 'model', 'seq'), **axes_kwargs)

  # Naming only some of the mesh axes is an error as well.
  with self.assertRaisesRegex(
      ValueError,
      'Number of axis names in axis_types should match the number of'
      ' axis_names'):
    jax.make_mesh((1, 1), ('data', 'model'), visible_axes='data')


@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyShardingTest(jtu.JaxTestCase):
Expand Down
17 changes: 10 additions & 7 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5975,7 +5975,8 @@ def f(x):
x = mesh_cast(x, P(None, None))
return x

self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
self.assertDictEqual(arr.sharding.mesh.axis_types,
{AxisTypes.Visible: ('x',)})
out = f(arr)
self.assertArraysEqual(out, np_inp)
self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
Expand All @@ -5986,26 +5987,28 @@ def test_inputs_different_context(self, mesh):
s = NamedSharding(mesh, P('x'))
arr = jax.device_put(np_inp, s)

auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Hidden: 'x'})
auto_mesh = jax.make_mesh((2,), 'x', hidden_axes='x')
with mesh_lib.use_mesh(auto_mesh):
arr2 = jnp.ones(8)
self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
self.assertDictEqual(arr2.sharding.mesh.axis_types,
{AxisTypes.Hidden: ('x',)})

@jax.jit
def f(x, y):
return x, y

out1, out2 = f(arr, arr2)
self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'})
self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'})
self.assertDictEqual(out1.sharding.mesh.axis_types,
{AxisTypes.Visible: ('x',)})
self.assertDictEqual(out2.sharding.mesh.axis_types,
{AxisTypes.Hidden: ('x',)})

@jtu.with_user_mesh((2,), 'x')
def test_output_different_context_error(self, mesh):
np_inp1 = np.arange(16).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
auto_mesh = jax.make_mesh((2,), 'x',
axis_types={AxisTypes.Hidden: 'x'}).abstract_mesh
auto_mesh = jax.make_mesh((2,), 'x', hidden_axes='x').abstract_mesh

@jax.jit
def f(x, y):
Expand Down

0 comments on commit 9e99aeb

Please sign in to comment.