From 66d3fe05afd1df323edf39ac4afa0b627119442c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 19 Nov 2023 11:12:04 -0800 Subject: [PATCH] [shard_map] add eager axis_index implementation and tests --- jax/experimental/shard_map.py | 9 +++--- tests/shard_map_test.py | 56 +++++++++++++++++++++++++++++------ 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index a4a513bfb861..ff26253d3cd7 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -774,11 +774,8 @@ def post_process_custom_vjp_call(self, out_tracers, _): "a feature request at https://github.com/google/jax/issues !") def process_axis_index(self, frame): - raise NotImplementedError( - "Eager evaluation of an `axis_index` inside a `shard_map` isn't yet " - "supported. " - "Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/google/jax/issues !") + with core.eval_context(), jax.disable_jit(False): + return jax.jit(lambda: jax.lax.axis_index(frame.name))() class ShardMapTracer(core.Tracer): @@ -812,6 +809,7 @@ def __str__(self) -> str: return '\n'.join( f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks)) + __repr__ = __str__ # for debuggers, like `p x` def _prim_applier(prim, params_tup, mesh, *args): def apply(*args): @@ -1641,6 +1639,7 @@ def full_lower(self) -> RewriteTracer: def __str__(self) -> str: return str(self.val) # TODO(mattjj): could show replication info here + __repr__ = __str__ # for debuggers, like `p x` class RewriteTrace(core.Trace): mesh: Mesh diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 6be4c23cfdff..323372cf4773 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -524,11 +524,11 @@ def f(x): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + @parameterized.parameters([True, False]) @jtu.run_on_devices('cpu', 'gpu', 'tpu') - def test_debug_print_jit(self): + def test_debug_print_jit(self, jit): mesh = Mesh(jax.devices(), ('i',)) - @jax.jit # NOTE: axis_index requires jit (at time of writing) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): idx = jax.lax.axis_index('i') @@ -537,6 +537,9 @@ def f(x): jax.debug.print("instance {i} has value y={y}", i=idx, y=y) return y + if jit: + f = jax.jit(f) + x = jnp.arange(2 * len(jax.devices())) with jtu.capture_stdout() as output: @@ -699,15 +702,50 @@ def foo_bwd(_, y_bar): with self.assertRaisesRegex(NotImplementedError, 'custom_vjp'): g(x) - def test_eager_notimplemented_error_message_axis_index(self): - def foo(x): - return x + jax.lax.axis_index('x') + @parameterized.parameters([True, False]) + def test_axis_index_basic(self, jit): + def foo(): + return jax.lax.axis_index('x')[None] + + if jit: + foo = jax.jit(foo) mesh = jtu.create_global_mesh((4,), ('x',)) - g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) - x = jnp.arange(4.) - with self.assertRaisesRegex(NotImplementedError, 'axis_index'): - g(x) + ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))() + expected = jnp.arange(4.) + self.assertAllClose(ans, expected, check_dtypes=False) + + @parameterized.parameters([True, False]) + def test_axis_index_twoaxes(self, jit): + def foo(): + out1 = jax.lax.axis_index('i')[None, None] + out2 = jax.lax.axis_index('j')[None, None] + out3 = jax.lax.axis_index(('i', 'j'))[None, None] + return out1, out2, out3 + + if jit: + foo = jax.jit(foo) + + mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(), + out_specs=P('i', 'j'))() + expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2)) + expected2 = jnp.arange(2.)[None, :] + jnp.zeros((4, 2)) + expected3 = jnp.arange(8.).reshape(4, 2) + self.assertAllClose(ans1, expected1, check_dtypes=False) + self.assertAllClose(ans2, expected2, check_dtypes=False) + self.assertAllClose(ans3, expected3, check_dtypes=False) + + def test_axis_index_eager(self): + mesh = jtu.create_global_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P()) + def foo(): + val = jax.lax.psum(jax.lax.axis_index('x'), 'x') + return 1. if val > 0 else -1. + + out = foo() # doesn't crash + self.assertEqual(out, 1.) def test_jaxpr_shardings_with_no_outputs(self): # https://github.com/google/jax/issues/15385