-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Pallas/TPU] Add all gather kernel example
PiperOrigin-RevId: 587963496
- Loading branch information
Showing
5 changed files
with
307 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright 2023 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Simple all-gather kernel. | ||
This is meant to be a pedagogical example of how to write a custom collective | ||
using Pallas. It doesn't have all possible performance optimizations and doesn't | ||
currently handle more diverse topologies. | ||
The kernel assumes a ring structure on a single mesh axis. It takes the local | ||
chunk, splits it in two, and sends each of the half-chunks in each direction | ||
(left and right) until every device has received the half chunks. | ||
""" | ||
from __future__ import annotations | ||
import functools | ||
|
||
from typing import Sequence | ||
|
||
import jax | ||
from jax import lax | ||
from jax.experimental import pallas as pl | ||
from jax.experimental import shard_map | ||
from jax.experimental.pallas import tpu as pltpu | ||
import jax.numpy as jnp | ||
|
||
|
||
P = jax.sharding.PartitionSpec | ||
|
||
|
||
def get_neighbor( | ||
idx: jax.Array, mesh: jax.sharding.Mesh, axis_name: str, *, direction: str | ||
) -> tuple[jax.Array, ...]: | ||
"""Helper function that computes the mesh indices of a neighbor.""" | ||
axis_names = mesh.axis_names | ||
which_axis = axis_names.index(axis_name) | ||
mesh_index = [ | ||
idx if i == which_axis else lax.axis_index(a) | ||
for i, a in enumerate(axis_names) | ||
] | ||
axis_size = lax.psum(1, axis_name) | ||
if direction == "right": | ||
next_idx = lax.rem(idx + 1, axis_size) | ||
else: | ||
left = idx - 1 | ||
next_idx = jnp.where(left < 0, left + axis_size, left) | ||
mesh_index[which_axis] = next_idx | ||
return tuple(mesh_index) | ||
|
||
|
||
def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str, | ||
mesh: jax.sharding.Mesh): | ||
my_id = lax.axis_index(axis_name) | ||
# TODO(sharadmv): could speed this up having the first remote DMA go from | ||
# x_ref->o_ref immediately instead of a blocking HBM copy. | ||
with pltpu.trace("initial_copy"): | ||
pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait() | ||
|
||
with pltpu.trace("neighbour_lookup"): | ||
axis_size = lax.psum(1, axis_name) | ||
left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left") | ||
right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right") | ||
|
||
with pltpu.trace("main_barrier"): | ||
sem = pltpu.get_barrier_semaphore() | ||
pltpu.semaphore_signal(sem, 2, device_id=left_neighbor) | ||
pltpu.semaphore_signal(sem, 2, device_id=right_neighbor) | ||
pltpu.semaphore_wait(sem, 2) | ||
|
||
shard_size = x_ref.shape[0] | ||
right_dma, left_dma = None, None | ||
# Main strategy for this AG: carve up our input into two slices. Send | ||
# each slice along each direction until they reach every device. | ||
for i in range(axis_size - 1): | ||
right_slot = my_id - i | ||
right_slice = pl.ds(shard_size // 2, shard_size // 2) | ||
slot = jnp.where(right_slot < 0, axis_size + right_slot, right_slot) | ||
if right_dma: | ||
with pltpu.trace("wait_right_dma"): | ||
right_dma.wait() | ||
right_dma = pltpu.async_remote_copy( | ||
o_ref.at[slot, right_slice], | ||
o_ref.at[slot, right_slice], | ||
send_sem[1], | ||
recv_sem[1], | ||
device_id=right_neighbor, | ||
) | ||
|
||
left_slot = my_id + i | ||
left_slice = pl.ds(0, shard_size // 2) | ||
slot = lax.rem(left_slot, axis_size) | ||
if left_dma: | ||
with pltpu.trace("wait_left_dma"): | ||
left_dma.wait() | ||
left_dma = pltpu.async_remote_copy( | ||
o_ref.at[slot, left_slice], | ||
o_ref.at[slot, left_slice], | ||
send_sem[0], | ||
recv_sem[0], | ||
device_id=left_neighbor, | ||
) | ||
with pltpu.trace("wait_all_dma"): | ||
assert right_dma is not None | ||
assert left_dma is not None | ||
right_dma.wait() | ||
left_dma.wait() | ||
|
||
|
||
@functools.partial( | ||
jax.jit, static_argnames=["mesh", "axis_name", "memory_space"] | ||
) | ||
def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str], | ||
memory_space: pltpu.TPUMemorySpace = pltpu.VMEM): | ||
if isinstance(axis_name, str): | ||
axis_name = (axis_name,) | ||
# TODO(sharadmv): enable all gather over multiple axes | ||
if len(axis_name) > 1: | ||
raise NotImplementedError("Only one axis supported.") | ||
axis_name, = axis_name | ||
if mesh.shape[axis_name] == 1: | ||
# We can short-circuit here if our axis size is 1 | ||
return x | ||
def ag_local(x_shard): | ||
axis_size = lax.psum(1, axis_name) | ||
out_shape = jax.ShapeDtypeStruct((axis_size, *x_shard.shape), x_shard.dtype) | ||
out = pl.pallas_call( | ||
functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), | ||
out_shape=out_shape, | ||
mosaic_params=dict(collective_id=0), | ||
grid_spec=pltpu.PrefetchScalarGridSpec( | ||
num_scalar_prefetch=0, | ||
scratch_shapes=( | ||
(pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA), | ||
(pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA), | ||
), | ||
in_specs=[pl.BlockSpec(memory_space=memory_space)], | ||
out_specs=pl.BlockSpec(memory_space=memory_space), | ||
), | ||
)(x_shard) | ||
return out.reshape((axis_size * x_shard.shape[0], *x_shard.shape[1:])) | ||
|
||
return shard_map.shard_map( | ||
ag_local, mesh=mesh, in_specs=P(axis_name), out_specs=P(None), | ||
check_rep=False | ||
)(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2023 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests the simple all_gather kernel.""" | ||
from __future__ import annotations | ||
|
||
from absl.testing import absltest | ||
import hypothesis as hp | ||
import hypothesis.strategies as hps | ||
import jax | ||
from jax import random | ||
from jax._src import test_util as jtu | ||
from jax.experimental import mesh_utils | ||
from jax.experimental.pallas import tpu as pltpu | ||
from jax.experimental.pallas.ops.tpu import all_gather | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
|
||
jax.config.parse_flags_with_absl() | ||
|
||
P = jax.sharding.PartitionSpec | ||
hp.settings.register_profile( | ||
"deterministic", | ||
database=None, | ||
derandomize=True, | ||
deadline=None, | ||
max_examples=50, | ||
print_blob=True, | ||
verbosity=hp.Verbosity.verbose, | ||
) | ||
hp.settings.load_profile("deterministic") | ||
|
||
|
||
@hps.composite | ||
def _array_shapes(draw): | ||
# TODO(sharadmv, apaszke): enable this on a wider variety of shapes | ||
valid_shapes = [ | ||
(128, 128), | ||
(256, 128), | ||
(256, 512), | ||
(256, 1024), | ||
# TODO(sharadmv,apaszke): enable these shapes | ||
# (256, 129), | ||
# (129, 128), | ||
# (64, 64), | ||
# (1, 1), | ||
] | ||
return draw(hps.sampled_from(valid_shapes)) | ||
|
||
|
||
@hps.composite | ||
def _array_dtypes(draw): | ||
return draw( | ||
hps.sampled_from([ | ||
jnp.float32, | ||
jnp.bfloat16, | ||
jnp.int32, | ||
# jnp.float16, # TODO(sharadmv,apaszke): enable float16 all gather | ||
# jnp.int16, # TODO(sharadmv,apaszke): enable int16 all gather | ||
# jnp.int8, # TODO(sharadmv,apaszke): enable int8 all gather | ||
]) | ||
) | ||
|
||
|
||
class AllGatherTest(jtu.JaxTestCase): | ||
|
||
@hp.given(hps.booleans(), _array_shapes(), _array_dtypes()) | ||
def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): | ||
if jax.device_count() < 2: | ||
self.skipTest("Need more devices") | ||
memory_space = pltpu.VMEM if is_vmem else pltpu.ANY | ||
mesh_shape = (jax.device_count(),) | ||
mesh = jax.sharding.Mesh( | ||
mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] | ||
) | ||
leading, *rest = shape | ||
shape = (mesh.shape["x"] * leading, *rest) | ||
x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) | ||
x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x"))) | ||
y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x", | ||
memory_space=memory_space) | ||
np.testing.assert_array_equal(y, x) | ||
|
||
@hp.given(hps.booleans(), _array_shapes(), _array_dtypes(), | ||
hps.sampled_from(["x", "y"])) | ||
def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, | ||
axis_name): | ||
if jax.device_count() < 2: | ||
self.skipTest("Need more devices") | ||
if jax.device_count() % 2: | ||
self.skipTest("Need an even number of devices") | ||
memory_space = pltpu.VMEM if is_vmem else pltpu.ANY | ||
mesh_shape = (2, jax.device_count() // 2) | ||
mesh = jax.sharding.Mesh( | ||
mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] | ||
) | ||
if axis_name == "x": | ||
sharding = jax.sharding.NamedSharding(mesh, P("x", None)) | ||
else: | ||
sharding = jax.sharding.NamedSharding(mesh, P("y", None)) | ||
leading, *rest = shape | ||
shape = (mesh.shape[axis_name] * leading, *rest) | ||
x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype) | ||
x_sharded = jax.device_put(x, sharding) | ||
y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name, | ||
memory_space=memory_space) | ||
np.testing.assert_array_equal(y, x) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main(testLoader=jtu.JaxTestLoader()) |