From a158e02b7d1c1a50e53adfec7f48bec69cc0dc5b Mon Sep 17 00:00:00 2001 From: Fabian Mentzer Date: Thu, 28 Nov 2024 05:34:52 -0800 Subject: [PATCH] Reverts cc5036cc18bc585b0d92a4f606956da084effbad PiperOrigin-RevId: 700998046 --- jax/_src/mesh_utils.py | 9 +-------- jax/_src/sharding_impls.py | 1 - tests/mesh_utils_test.py | 6 ------ tests/pjit_test.py | 6 ------ 4 files changed, 1 insertion(+), 21 deletions(-) diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py index d227b1eeeea9..16e34e1afaef 100644 --- a/jax/_src/mesh_utils.py +++ b/jax/_src/mesh_utils.py @@ -705,12 +705,6 @@ def _transpose_trick( *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] ) -def _validate_axis_shapes(axis_shapes: Sequence[int], arg_name: str, - fun_name: str): - if not all(isinstance(s, int) for s in axis_shapes): - raise ValueError( - f'{arg_name} passed to {fun_name} should be a sequence of ints. Got' - f' {axis_shapes}') def create_device_mesh( mesh_shape: Sequence[int], @@ -746,8 +740,7 @@ def create_device_mesh( """ if devices is None: devices = xb.devices() - _validate_axis_shapes(mesh_shape, 'mesh_shape', 'create_device_mesh') - if math.prod(mesh_shape) != len(devices): + if np.prod(mesh_shape) != len(devices): raise ValueError( f'Number of devices {len(devices)} must equal the product ' f'of mesh_shape {mesh_shape}' diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 39d8aedfe7ad..8abe58e52a74 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -1714,7 +1714,6 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], """ if devices is None: devices = xla_bridge.devices() - mesh_utils._validate_axis_shapes(axis_shapes, 'axis_shapes', 'make_mesh') axis_size = math.prod(axis_shapes) if axis_size > len(devices): raise ValueError( diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index d4db8fd3d406..66f1fc9f6cfb 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -353,12 +353,6 @@ def test_create_device_mesh_for_nd_torus( ) self.assertArraysEqual(assignment, expected_assignment_matrix) - def test_create_device_mesh_non_int_error(self): - with self.assertRaisesRegex( - ValueError, - "mesh_shape passed to create_device_mesh should be a sequence of ints"): - mesh_utils.create_device_mesh(((4,), 4)) - @parameterized.named_parameters( ('2x2x1', mock_2x2x1_devices,), ('2x2x4', mock_2x2x4_devices, ), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6bd05536cebc..e541c6346666 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4458,12 +4458,6 @@ def g(x): self.assertEqual(out2.sharding, s) self.assertEqual(out2.dtype, np.float32) - def test_make_mesh_non_int_error(self): - with self.assertRaisesRegex( - ValueError, - "axis_shapes passed to make_mesh should be a sequence of ints"): - jax.make_mesh(((4,), 4), ('x', 'y')) - def test_jnp_array_reshard_error(self): if jax.device_count() < 2: self.skipTest('Requires >=2 devices')