From 9e99aebdc1dabfc90371bff882bd08599b465793 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 20 Jan 2025 20:47:10 -0800 Subject: [PATCH] Make `make_mesh` take `visible_axes`, `hidden_axes` and `collective_axes` as parameters instead of `axis_types` to make it a cleaner API. The mesh axis names provided to those parameters should be disjoint, i.e. they must not overlap. PiperOrigin-RevId: 717718333 --- jax/_src/sharding_impls.py | 37 +++++++++++++++++++++++++++++++- jax/_src/test_util.py | 10 ++++++++- tests/array_test.py | 43 ++++++++++++++++++++++++++++++++++++++ tests/pjit_test.py | 17 ++++++++------- 4 files changed, 98 insertions(+), 9 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index b2a482be5450..bd532d1051f5 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1783,10 +1783,44 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, f' {s} or type: {sharding.mesh._name_to_type[s]}') return sharding +TypeOfAxis = str | tuple[str, ...] 
| None + +def _normalize(axes: TypeOfAxis = None) -> tuple[str, ...]: + if axes is None: + return () + return (axes,) if isinstance(axes, str) else axes + +def _get_axis_types( + hidden_axes: TypeOfAxis = None, visible_axes: TypeOfAxis = None, + collective_axes: TypeOfAxis = None): + if hidden_axes is None and visible_axes is None and collective_axes is None: + return None + + hidden_axes = _normalize(hidden_axes) + visible_axes = _normalize(visible_axes) + collective_axes = _normalize(collective_axes) + + ha, va, ca = set(hidden_axes), set(visible_axes), set(collective_axes) + disjoint = ha.isdisjoint(va) and ha.isdisjoint(ca) and va.isdisjoint(ca) + if not disjoint: + raise ValueError( + f'{hidden_axes=}, {visible_axes=} and {collective_axes=} should be' + ' non-overlapping.') + + out = {} + if hidden_axes: + out.update({mesh_lib.AxisTypes.Hidden: hidden_axes}) + if visible_axes: + out.update({mesh_lib.AxisTypes.Visible: visible_axes}) + if collective_axes: + out.update({mesh_lib.AxisTypes.Collective: collective_axes}) + return out + def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], *, devices: Sequence[xc.Device] | None = None, - axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh: + hidden_axes: TypeOfAxis = None, visible_axes: TypeOfAxis = None, + collective_axes: TypeOfAxis = None) -> mesh_lib.Mesh: """Creates an efficient mesh with the shape and axis names specified. 
This function attempts to automatically compute a good mapping from a set of @@ -1848,4 +1882,5 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], mesh_devices = mesh_utils.create_device_mesh( new_axis_shapes, devices, allow_split_physical_axes=allow_split_physical_axes) + axis_types = _get_axis_types(hidden_axes, visible_axes, collective_axes) return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 69bb34669a8a..3017890e6e04 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1569,7 +1569,15 @@ def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None): mesh_devices = np.array(devices[:size]).reshape(mesh_shape) return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types) else: - return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types) + if axis_types is None: + visible_axes = hidden_axes = collective_axes = None + else: + visible_axes = axis_types.get(mesh_lib.AxisTypes.Visible, None) + hidden_axes = axis_types.get(mesh_lib.AxisTypes.Hidden, None) + collective_axes = axis_types.get(mesh_lib.AxisTypes.Collective, None) + return jax.make_mesh(mesh_shape, axis_names, visible_axes=visible_axes, + hidden_axes=hidden_axes, + collective_axes=collective_axes) class _cached_property: null = object() diff --git a/tests/array_test.py b/tests/array_test.py index 8a4cb647bd44..13fb1295a9ec 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -31,6 +31,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import dialects, ir from jax._src.util import safe_zip +from jax._src.mesh import AxisTypes from jax._src.sharding import common_devices_indices_map from jax._src.sharding_impls import ( _op_sharding_to_pos_sharding, pmap_sharding_devices_indices_map, @@ -1313,6 +1314,48 @@ def test_mesh_axis_types_mismatch(self): jax.sharding.AbstractMesh((('x', 2), ('y', 1)), 
axis_types={jax.sharding.AxisTypes.Hidden: 'x'}) + def test_make_mesh_axis_types(self): + mesh = jax.make_mesh((1, 1), ('x', 'y')) + self.assertDictEqual(mesh.axis_types, {AxisTypes.Hidden: ('x', 'y')}) + + mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), visible_axes='x', + hidden_axes='y', collective_axes='z') + self.assertDictEqual( + mesh.axis_types, {AxisTypes.Hidden: ('y',), AxisTypes.Visible: ('x',), + AxisTypes.Collective: ('z',)}) + + mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), visible_axes=('x', 'y'), + collective_axes='z') + self.assertDictEqual(mesh.axis_types, {AxisTypes.Visible: ('x', 'y'), + AxisTypes.Collective: ('z',)}) + + mesh = jax.make_mesh((1, 1), ('x', 'y'), visible_axes=('x', 'y')) + self.assertDictEqual(mesh.axis_types, {AxisTypes.Visible: ('x', 'y')}) + + mesh = jax.make_mesh((1,), 'model', collective_axes='model') + self.assertDictEqual(mesh.axis_types, {AxisTypes.Collective: ('model',)}) + + with self.assertRaisesRegex(ValueError, "should be non-overlapping"): + jax.make_mesh((1, 1, 1), ('data', 'model', 'seq'), + hidden_axes='data', visible_axes=('data', 'seq'), + collective_axes='model') + + with self.assertRaisesRegex(ValueError, "should be non-overlapping"): + jax.make_mesh((1, 1, 1), ('data', 'model', 'seq'), + hidden_axes='data', visible_axes='model', + collective_axes='data') + + with self.assertRaisesRegex(ValueError, "should be non-overlapping"): + jax.make_mesh((1, 1, 1), ('data', 'model', 'seq'), + visible_axes=('data', 'seq'), + collective_axes=('seq', 'model')) + + with self.assertRaisesRegex( + ValueError, + 'Number of axis names in axis_types should match the number of' + ' axis_names'): + jax.make_mesh((1, 1), ('data', 'model'), visible_axes='data') + @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyShardingTest(jtu.JaxTestCase): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 5ccb9a7595ed..7f120f15b865 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5975,7 +5975,8 @@ def f(x): x 
= mesh_cast(x, P(None, None)) return x - self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'}) + self.assertDictEqual(arr.sharding.mesh.axis_types, + {AxisTypes.Visible: ('x',)}) out = f(arr) self.assertArraysEqual(out, np_inp) self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'}) @@ -5986,26 +5987,28 @@ def test_inputs_different_context(self, mesh): s = NamedSharding(mesh, P('x')) arr = jax.device_put(np_inp, s) - auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Hidden: 'x'}) + auto_mesh = jax.make_mesh((2,), 'x', hidden_axes='x') with mesh_lib.use_mesh(auto_mesh): arr2 = jnp.ones(8) - self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'}) + self.assertDictEqual(arr2.sharding.mesh.axis_types, + {AxisTypes.Hidden: ('x',)}) @jax.jit def f(x, y): return x, y out1, out2 = f(arr, arr2) - self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.Visible: 'x'}) - self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Hidden: 'x'}) + self.assertDictEqual(out1.sharding.mesh.axis_types, + {AxisTypes.Visible: ('x',)}) + self.assertDictEqual(out2.sharding.mesh.axis_types, + {AxisTypes.Hidden: ('x',)}) @jtu.with_user_mesh((2,), 'x') def test_output_different_context_error(self, mesh): np_inp1 = np.arange(16).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x'))) - auto_mesh = jax.make_mesh((2,), 'x', - axis_types={AxisTypes.Hidden: 'x'}).abstract_mesh + auto_mesh = jax.make_mesh((2,), 'x', hidden_axes='x').abstract_mesh @jax.jit def f(x, y):