Skip to content

Commit

Permalink
[Pallas/TPU] Add all gather kernel example
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587963496
  • Loading branch information
sharadmv authored and jax authors committed Dec 5, 2023
1 parent 91ef37b commit a31129a
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 3 deletions.
7 changes: 7 additions & 0 deletions jax/_src/pallas/mosaic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
class AbstractSemaphore(jax_core.AbstractValue):
sem_type: SemaphoreType

def join(self, other):
  """Joins this semaphore abstract value with `other`.

  Two semaphore abstract values can only be joined when they describe the
  same kind of semaphore; in that case `self` is returned unchanged.

  Args:
    other: the abstract value to join with.

  Returns:
    `self`.

  Raises:
    ValueError: if `other` is not an `AbstractSemaphore`, or if its
      `sem_type` differs from this one's.
  """
  if not isinstance(other, AbstractSemaphore):
    raise ValueError(
        f"Cannot join AbstractSemaphore with {type(other).__name__}."
    )
  if other.sem_type != self.sem_type:
    raise ValueError(
        "Cannot join semaphores of different types: "
        f"{self.sem_type} vs. {other.sem_type}."
    )
  return self

# Register the raise-to-shaped handler for AbstractSemaphore: the abstract
# value is returned unchanged and the handler's second argument is ignored.
jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval


Expand Down
11 changes: 8 additions & 3 deletions jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,14 @@ class DeviceIdType(enum.Enum):
semaphore_signal_p = jax_core.Primitive('semaphore_signal')
semaphore_signal_p.multiple_results = True

def semaphore_signal(sem, inc: int | jax.Array = 1,
*, device_id: int | jax.Array | None = None,
device_id_type: DeviceIdType = DeviceIdType.MESH):

def semaphore_signal(
sem,
inc: int | jax.Array = 1,
*,
device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None,
device_id_type: DeviceIdType = DeviceIdType.MESH,
):
inc = jnp.asarray(inc, dtype=jnp.int32)
args = [sem, inc]
has_device_id = device_id is not None
Expand Down
155 changes: 155 additions & 0 deletions jax/experimental/pallas/ops/tpu/all_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple all-gather kernel.

This is meant to be a pedagogical example of how to write a custom collective
using Pallas. It doesn't have all possible performance optimizations and doesn't
currently handle more diverse topologies.

The kernel assumes a ring structure on a single mesh axis. It takes the local
chunk, splits it in two, and sends one half-chunk in each direction (left and
right) around the ring until every device has received every half-chunk.
"""
from __future__ import annotations
import functools

from typing import Sequence

import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp


# Short alias for the partition spec type used in the shard_map in/out specs.
P = jax.sharding.PartitionSpec


def get_neighbor(
    idx: jax.Array, mesh: jax.sharding.Mesh, axis_name: str, *, direction: str
) -> tuple[jax.Array, ...]:
  """Computes the mesh indices of a ring neighbor along `axis_name`.

  Args:
    idx: position along `axis_name` whose neighbor is wanted.
    mesh: mesh whose `axis_names` determine the order of the returned indices.
    axis_name: name of the mesh axis that is treated as a ring.
    direction: "right" selects position `idx + 1`, "left" selects position
      `idx - 1`; both wrap around the size of the axis.

  Returns:
    A tuple with one entry per mesh axis, in `mesh.axis_names` order. The
    entry for `axis_name` is the neighbor's index; every other entry is
    `lax.axis_index` of that axis.

  Raises:
    ValueError: if `direction` is neither "left" nor "right".
  """
  # Validate eagerly: previously any value other than "right" silently meant
  # "left", so a typo such as "rigth" sent data the wrong way around the ring.
  if direction not in ("left", "right"):
    raise ValueError(
        f"direction must be 'left' or 'right', got {direction!r}."
    )
  axis_names = mesh.axis_names
  which_axis = axis_names.index(axis_name)
  mesh_index = [
      idx if i == which_axis else lax.axis_index(a)
      for i, a in enumerate(axis_names)
  ]
  # Summing the constant 1 over the axis yields the number of devices on it.
  axis_size = lax.psum(1, axis_name)
  if direction == "right":
    next_idx = lax.rem(idx + 1, axis_size)
  else:
    # Wrap position -1 around to the last position on the axis.
    left = idx - 1
    next_idx = jnp.where(left < 0, left + axis_size, left)
  mesh_index[which_axis] = next_idx
  return tuple(mesh_index)


def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
              mesh: jax.sharding.Mesh):
  """Ring all-gather kernel body: gathers every device's `x_ref` into `o_ref`.

  The local shard is first copied into slot `my_id` of `o_ref`. After a
  barrier with both ring neighbors, the kernel runs `axis_size - 1` steps; in
  each step it forwards the trailing rows of one slot to the right neighbor
  and the leading rows of another slot to the left neighbor.

  Args:
    x_ref: reference to the local shard; its leading dimension is split in two.
    o_ref: output reference, indexed as `o_ref.at[slot, row_slice]` with one
      slot per device on `axis_name`.
    send_sem: pair of DMA semaphores; index 0 is used for left-bound sends and
      index 1 for right-bound sends.
    recv_sem: pair of DMA semaphores; index 0 is used for the initial local
      copy and for the left-bound transfers, index 1 for the right-bound ones.
    axis_name: mesh axis that forms the ring.
    mesh: mesh used to translate ring positions into device ids.
  """
  my_id = lax.axis_index(axis_name)
  # TODO(sharadmv): could speed this up having the first remote DMA go from
  # x_ref->o_ref immediately instead of a blocking HBM copy.
  with pltpu.trace("initial_copy"):
    pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait()

  with pltpu.trace("neighbour_lookup"):
    axis_size = lax.psum(1, axis_name)
    left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left")
    right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right")

  with pltpu.trace("main_barrier"):
    sem = pltpu.get_barrier_semaphore()
    # Every device is signalled once by each of its two neighbors, so each
    # signal carries 1 and the wait below consumes the total of 2. Signalling
    # 2 per neighbor would let the wait return after only one neighbor had
    # arrived and would leave a non-zero count on the barrier semaphore.
    # NOTE(review): relies on semaphore_wait(sem, n) blocking until n has been
    # signalled and then decrementing by n -- confirm against the Pallas docs.
    pltpu.semaphore_signal(sem, 1, device_id=left_neighbor)
    pltpu.semaphore_signal(sem, 1, device_id=right_neighbor)
    pltpu.semaphore_wait(sem, 2)

  shard_size = x_ref.shape[0]
  # The two row ranges are loop-invariant. The right-bound half takes the
  # remainder so that an odd `shard_size` does not drop its last row; for even
  # sizes this is identical to two halves of `shard_size // 2` rows.
  # NOTE(review): whether unevenly sized DMAs lower on every TPU generation
  # has not been verified here.
  left_size = shard_size // 2
  left_slice = pl.ds(0, left_size)
  right_slice = pl.ds(left_size, shard_size - left_size)

  right_dma, left_dma = None, None
  # Main strategy for this AG: carve up our input into two slices. Send
  # each slice along each direction until they reach every device.
  for i in range(axis_size - 1):
    # Slot forwarded to the right in step i: our own slot in step 0, then the
    # slot `i` positions to our left (wrapping around).
    right_slot = my_id - i
    slot = jnp.where(right_slot < 0, axis_size + right_slot, right_slot)
    if right_dma is not None:
      # The previous right-bound transfer must finish before reusing its
      # semaphores for the next one.
      with pltpu.trace("wait_right_dma"):
        right_dma.wait()
    right_dma = pltpu.async_remote_copy(
        o_ref.at[slot, right_slice],
        o_ref.at[slot, right_slice],
        send_sem[1],
        recv_sem[1],
        device_id=right_neighbor,
    )

    # Slot forwarded to the left in step i: the slot `i` positions to our
    # right (wrapping around).
    left_slot = my_id + i
    slot = lax.rem(left_slot, axis_size)
    if left_dma is not None:
      with pltpu.trace("wait_left_dma"):
        left_dma.wait()
    left_dma = pltpu.async_remote_copy(
        o_ref.at[slot, left_slice],
        o_ref.at[slot, left_slice],
        send_sem[0],
        recv_sem[0],
        device_id=left_neighbor,
    )
  with pltpu.trace("wait_all_dma"):
    # The loop runs at least once whenever axis_size > 1, so both DMAs exist.
    assert right_dma is not None
    assert left_dma is not None
    right_dma.wait()
    left_dma.wait()


@functools.partial(
    jax.jit, static_argnames=["mesh", "axis_name", "memory_space"]
)
def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
               memory_space: pltpu.TPUMemorySpace = pltpu.VMEM):
  """All-gathers `x` along a single mesh axis using the Pallas ring kernel.

  Args:
    x: array whose leading dimension is sharded over `axis_name`.
    mesh: device mesh to run on.
    axis_name: a mesh axis name, or a sequence holding exactly one name.
    memory_space: memory space requested for the kernel's input and output.

  Returns:
    `x` itself when the axis has size 1; otherwise the gathered array, with
    the per-device shards concatenated along the leading dimension.

  Raises:
    NotImplementedError: if more than one axis name is given.
  """
  axes = (axis_name,) if isinstance(axis_name, str) else axis_name
  # TODO(sharadmv): enable all gather over multiple axes
  if len(axes) > 1:
    raise NotImplementedError("Only one axis supported.")
  (axis,) = axes
  if mesh.shape[axis] == 1:
    # We can short-circuit here if our axis size is 1
    return x

  def ag_local(x_shard):
    num_devices = lax.psum(1, axis)
    kernel = functools.partial(ag_kernel, axis_name=axis, mesh=mesh)
    # One output slot per device on the axis, each shaped like the local shard.
    slots_shape = jax.ShapeDtypeStruct(
        (num_devices, *x_shard.shape), x_shard.dtype
    )
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        # Two pairs of DMA semaphores, passed to the kernel as scratch.
        scratch_shapes=(
            (pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
            (pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
        ),
        in_specs=[pl.BlockSpec(memory_space=memory_space)],
        out_specs=pl.BlockSpec(memory_space=memory_space),
    )
    gathered = pl.pallas_call(
        kernel,
        out_shape=slots_shape,
        mosaic_params=dict(collective_id=0),
        grid_spec=grid_spec,
    )(x_shard)
    # Fold the slot dimension into the shard's leading dimension.
    flat_shape = (num_devices * x_shard.shape[0], *x_shard.shape[1:])
    return gathered.reshape(flat_shape)

  sharded_gather = shard_map.shard_map(
      ag_local,
      mesh=mesh,
      in_specs=P(axis),
      out_specs=P(None),
      check_rep=False,
  )
  return sharded_gather(x)
14 changes: 14 additions & 0 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,17 @@ py_test(
"//third_party/py/jax:pallas",
] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)

# Test target for all_gather_test.py; the cpu and gpu backends are disabled.
jax_test(
name = "all_gather_test",
srcs = [
"all_gather_test.py",
],
disable_backends = [
"cpu",
"gpu",
],
deps = [
"//third_party/py/jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
123 changes: 123 additions & 0 deletions tests/pallas/all_gather_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests the simple all_gather kernel."""
from __future__ import annotations

from absl.testing import absltest
import hypothesis as hp
import hypothesis.strategies as hps
import jax
from jax import random
from jax._src import test_util as jtu
from jax.experimental import mesh_utils
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu import all_gather
import jax.numpy as jnp
import numpy as np


# Let JAX pick up its configuration from absl command-line flags.
jax.config.parse_flags_with_absl()

# Short alias for the partition spec type used when sharding test inputs.
P = jax.sharding.PartitionSpec
# Register and activate a hypothesis profile named "deterministic":
# derandomized, no example database, no per-example deadline, at most 50
# examples, verbose output with the reproduction blob printed.
hp.settings.register_profile(
"deterministic",
database=None,
derandomize=True,
deadline=None,
max_examples=50,
print_blob=True,
verbosity=hp.Verbosity.verbose,
)
hp.settings.load_profile("deterministic")


@hps.composite
def _array_shapes(draw):
  """Hypothesis strategy that draws one of the enabled per-device shapes."""
  # TODO(sharadmv, apaszke): enable this on a wider variety of shapes
  shape_strategy = hps.sampled_from([
      (128, 128),
      (256, 128),
      (256, 512),
      (256, 1024),
      # TODO(sharadmv,apaszke): enable these shapes
      # (256, 129),
      # (129, 128),
      # (64, 64),
      # (1, 1),
  ])
  return draw(shape_strategy)


@hps.composite
def _array_dtypes(draw):
  """Hypothesis strategy that draws one of the enabled element dtypes."""
  enabled_dtypes = [
      jnp.float32,
      jnp.bfloat16,
      jnp.int32,
      # jnp.float16,  # TODO(sharadmv,apaszke): enable float16 all gather
      # jnp.int16,  # TODO(sharadmv,apaszke): enable int16 all gather
      # jnp.int8,  # TODO(sharadmv,apaszke): enable int8 all gather
  ]
  return draw(hps.sampled_from(enabled_dtypes))


class AllGatherTest(jtu.JaxTestCase):
  """Checks that the Pallas all-gather reproduces the unsharded input."""

  @hp.given(hps.booleans(), _array_shapes(), _array_dtypes())
  def test_all_gather_1d_mesh(self, is_vmem, shape, dtype):
    """All-gathers over a 1D mesh that spans every available device."""
    if jax.device_count() < 2:
      self.skipTest("Need more devices")
    devices = mesh_utils.create_device_mesh(
        (jax.device_count(),), jax.devices()
    )
    mesh = jax.sharding.Mesh(devices, ["x"])
    # `shape` is the per-device shape; scale its leading dimension by the
    # number of devices to get the global array shape.
    per_shard_rows, *trailing = shape
    full_shape = (mesh.shape["x"] * per_shard_rows, *trailing)
    expected = random.normal(
        random.key(0), full_shape, dtype=jnp.float32
    ).astype(dtype)
    sharded_input = jax.device_put(
        expected, jax.sharding.NamedSharding(mesh, P("x"))
    )
    gathered = all_gather.all_gather(
        sharded_input,
        mesh=mesh,
        axis_name="x",
        memory_space=pltpu.VMEM if is_vmem else pltpu.ANY,
    )
    np.testing.assert_array_equal(gathered, expected)

  @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(),
            hps.sampled_from(["x", "y"]))
  def test_all_gather_2d_mesh(self, is_vmem, shape, dtype,
                              axis_name):
    """All-gathers along one axis of a (2, device_count // 2) mesh."""
    if jax.device_count() < 2:
      self.skipTest("Need more devices")
    if jax.device_count() % 2:
      self.skipTest("Need an even number of devices")
    devices = mesh_utils.create_device_mesh(
        (2, jax.device_count() // 2), jax.devices()
    )
    mesh = jax.sharding.Mesh(devices, ["x", "y"])
    # Shard the leading dimension over whichever axis is being gathered.
    sharding = jax.sharding.NamedSharding(mesh, P(axis_name, None))
    per_shard_rows, *trailing = shape
    full_shape = (mesh.shape[axis_name] * per_shard_rows, *trailing)
    expected = random.normal(
        random.key(0), full_shape, dtype=jnp.float32
    ).astype(dtype)
    sharded_input = jax.device_put(expected, sharding)
    gathered = all_gather.all_gather(
        sharded_input,
        mesh=mesh,
        axis_name=axis_name,
        memory_space=pltpu.VMEM if is_vmem else pltpu.ANY,
    )
    np.testing.assert_array_equal(gathered, expected)


# When executed directly, run the tests through absltest using jtu.JaxTestLoader.
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit a31129a

Please sign in to comment.