Skip to content

Commit

Permalink
Reverts cc5036c
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700998046
  • Loading branch information
Fabian Mentzer authored and Google-ML-Automation committed Nov 28, 2024
1 parent 34fe66b commit a158e02
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 21 deletions.
9 changes: 1 addition & 8 deletions jax/_src/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,12 +705,6 @@ def _transpose_trick(
*_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims]
)

def _validate_axis_shapes(axis_shapes: Sequence[int], arg_name: str,
fun_name: str):
if not all(isinstance(s, int) for s in axis_shapes):
raise ValueError(
f'{arg_name} passed to {fun_name} should be a sequence of ints. Got'
f' {axis_shapes}')

def create_device_mesh(
mesh_shape: Sequence[int],
Expand Down Expand Up @@ -746,8 +740,7 @@ def create_device_mesh(
"""
if devices is None:
devices = xb.devices()
_validate_axis_shapes(mesh_shape, 'mesh_shape', 'create_device_mesh')
if math.prod(mesh_shape) != len(devices):
if np.prod(mesh_shape) != len(devices):
raise ValueError(
f'Number of devices {len(devices)} must equal the product '
f'of mesh_shape {mesh_shape}'
Expand Down
1 change: 0 additions & 1 deletion jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,6 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
"""
if devices is None:
devices = xla_bridge.devices()
mesh_utils._validate_axis_shapes(axis_shapes, 'axis_shapes', 'make_mesh')
axis_size = math.prod(axis_shapes)
if axis_size > len(devices):
raise ValueError(
Expand Down
6 changes: 0 additions & 6 deletions tests/mesh_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,6 @@ def test_create_device_mesh_for_nd_torus(
)
self.assertArraysEqual(assignment, expected_assignment_matrix)

def test_create_device_mesh_non_int_error(self):
with self.assertRaisesRegex(
ValueError,
"mesh_shape passed to create_device_mesh should be a sequence of ints"):
mesh_utils.create_device_mesh(((4,), 4))

@parameterized.named_parameters(
('2x2x1', mock_2x2x1_devices,),
('2x2x4', mock_2x2x4_devices, ),
Expand Down
6 changes: 0 additions & 6 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4458,12 +4458,6 @@ def g(x):
self.assertEqual(out2.sharding, s)
self.assertEqual(out2.dtype, np.float32)

def test_make_mesh_non_int_error(self):
with self.assertRaisesRegex(
ValueError,
"axis_shapes passed to make_mesh should be a sequence of ints"):
jax.make_mesh(((4,), 4), ('x', 'y'))

def test_jnp_array_reshard_error(self):
if jax.device_count() < 2:
self.skipTest('Requires >=2 devices')
Expand Down

0 comments on commit a158e02

Please sign in to comment.