Skip to content

Commit

Permalink
Merge pull request #18599 from mattjj:shmap-eager-axis-index
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584423708
  • Loading branch information
jax authors committed Nov 21, 2023
2 parents 2efa586 + 66d3fe0 commit b48254e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
9 changes: 4 additions & 5 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,11 +774,8 @@ def post_process_custom_vjp_call(self, out_tracers, _):
"a feature request at https://github.com/google/jax/issues !")

def process_axis_index(self, frame):
  """Eagerly evaluate `axis_index` for the mesh axis named by `frame.name`.

  The scraped diff fused the removed `raise NotImplementedError(...)` lines
  with the added implementation; only the added implementation belongs here.
  Lowering of `axis_index` requires a jit, so we wrap the call in `jax.jit`
  under an eager evaluation context (and re-enable jit in case the caller
  ran under `jax.disable_jit()`).
  """
  # NOTE(review): `core` is a module-level import of this file (not visible
  # in this chunk); `frame` presumably carries the axis name — confirm.
  with core.eval_context(), jax.disable_jit(False):
    return jax.jit(lambda: jax.lax.axis_index(frame.name))()


class ShardMapTracer(core.Tracer):
Expand Down Expand Up @@ -812,6 +809,7 @@ def __str__(self) -> str:
return '\n'.join(
f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))
__repr__ = __str__ # for debuggers, like `p x`

def _prim_applier(prim, params_tup, mesh, *args):
def apply(*args):
Expand Down Expand Up @@ -1641,6 +1639,7 @@ def full_lower(self) -> RewriteTracer:

def __str__(self) -> str:
  """Render this tracer as the string form of its underlying value."""
  return str(self.val)  # TODO(mattjj): could show replication info here
__repr__ = __str__  # for debuggers, like `p x`

class RewriteTrace(core.Trace):
mesh: Mesh
Expand Down
56 changes: 47 additions & 9 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,11 @@ def f(x):
self.assertIn('out_names', e.params)
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))

@parameterized.parameters([True, False])
@jtu.run_on_devices('cpu', 'gpu', 'tpu')
def test_debug_print_jit(self):
def test_debug_print_jit(self, jit):
mesh = Mesh(jax.devices(), ('i',))

@jax.jit # NOTE: axis_index requires jit (at time of writing)
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(x):
idx = jax.lax.axis_index('i')
Expand All @@ -537,6 +537,9 @@ def f(x):
jax.debug.print("instance {i} has value y={y}", i=idx, y=y)
return y

if jit:
f = jax.jit(f)

x = jnp.arange(2 * len(jax.devices()))

with jtu.capture_stdout() as output:
Expand Down Expand Up @@ -699,15 +702,50 @@ def foo_bwd(_, y_bar):
with self.assertRaisesRegex(NotImplementedError, 'custom_vjp'):
g(x)

@parameterized.parameters([True, False])
def test_axis_index_basic(self, jit):
  """`axis_index` inside shard_map works both eagerly and under jit.

  The scraped diff fused the deleted NotImplementedError-message test with
  this added test; only the added test belongs here.
  """
  def foo():
    # [None] lifts the scalar index to rank 1 so out_specs=P('x') can
    # concatenate one element per shard into jnp.arange(4.).
    return jax.lax.axis_index('x')[None]

  if jit:
    foo = jax.jit(foo)

  mesh = jtu.create_global_mesh((4,), ('x',))
  ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))()
  expected = jnp.arange(4.)
  self.assertAllClose(ans, expected, check_dtypes=False)

@parameterized.parameters([True, False])
def test_axis_index_twoaxes(self, jit):
  """`axis_index` over each axis and over an axis tuple, eager and jitted."""
  def body():
    # [None, None] lifts each scalar index to rank 2 so out_specs=P('i', 'j')
    # can tile the per-shard blocks back into (4, 2) arrays.
    idx_i = jax.lax.axis_index('i')[None, None]
    idx_j = jax.lax.axis_index('j')[None, None]
    idx_ij = jax.lax.axis_index(('i', 'j'))[None, None]
    return idx_i, idx_j, idx_ij

  if jit:
    body = jax.jit(body)

  mesh = jtu.create_global_mesh((4, 2), ('i', 'j'))
  got_i, got_j, got_ij = shard_map(
      body, mesh, in_specs=(), out_specs=P('i', 'j'))()
  want_i = jnp.arange(4.)[:, None] + jnp.zeros((4, 2))
  want_j = jnp.arange(2.)[None, :] + jnp.zeros((4, 2))
  want_ij = jnp.arange(8.).reshape(4, 2)
  self.assertAllClose(got_i, want_i, check_dtypes=False)
  self.assertAllClose(got_j, want_j, check_dtypes=False)
  self.assertAllClose(got_ij, want_ij, check_dtypes=False)

def test_axis_index_eager(self):
  """Eager (unjitted) `axis_index`: its value feeds Python-level control flow."""
  mesh = jtu.create_global_mesh((4,), ('x',))

  @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P())
  def foo():
    # psum of indices 0+1+2+3 = 6 > 0, so the Python branch yields 1.
    total = jax.lax.psum(jax.lax.axis_index('x'), 'x')
    return 1. if total > 0 else -1.

  out = foo()  # doesn't crash
  self.assertEqual(out, 1.)

def test_jaxpr_shardings_with_no_outputs(self):
# https://github.com/google/jax/issues/15385
Expand Down

0 comments on commit b48254e

Please sign in to comment.