From a31129a869b5cb0ae4d91b0f0f4038ec50a02556 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Tue, 5 Dec 2023 00:09:34 -0800
Subject: [PATCH] [Pallas/TPU] Add all gather kernel example

PiperOrigin-RevId: 587963496
---
 jax/_src/pallas/mosaic/core.py                |  17 ++
 jax/_src/pallas/mosaic/primitives.py          |  11 +-
 jax/experimental/pallas/ops/tpu/all_gather.py | 204 ++++++++++++++++++
 tests/pallas/BUILD                            |  14 +
 tests/pallas/all_gather_test.py               | 125 +++++++++++
 5 files changed, 368 insertions(+), 3 deletions(-)
 create mode 100644 jax/experimental/pallas/ops/tpu/all_gather.py
 create mode 100644 tests/pallas/all_gather_test.py

diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index a5fade94f61e..8b2e835fd468 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -102,6 +102,23 @@ def _ref_raise_to_shaped(ref_aval: AbstractMemoryRef, weak_type):
 class AbstractSemaphore(jax_core.AbstractValue):
   sem_type: SemaphoreType
 
+  def join(self, other):
+    """Joins two semaphore avals; they must have the same semaphore type.
+
+    Returns `self` when `other` is an `AbstractSemaphore` with an equal
+    `sem_type`, and raises `ValueError` otherwise.
+    """
+    if not isinstance(other, AbstractSemaphore):
+      raise ValueError(
+          f"Cannot join AbstractSemaphore with {type(other).__name__}."
+      )
+    if other.sem_type != self.sem_type:
+      raise ValueError(
+          "Cannot join semaphores of different types:"
+          f" {self.sem_type} != {other.sem_type}."
+      )
+    return self
+
 jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval
 
 
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index f7fb5108470d..f00a18f66806 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -118,9 +118,14 @@ class DeviceIdType(enum.Enum):
 semaphore_signal_p = jax_core.Primitive('semaphore_signal')
 semaphore_signal_p.multiple_results = True
 
-def semaphore_signal(sem, inc: int | jax.Array = 1,
-                     *, device_id: int | jax.Array | None = None,
-                     device_id_type: DeviceIdType = DeviceIdType.MESH):
+
+def semaphore_signal(
+    sem,
+    inc: int | jax.Array = 1,
+    *,
+    device_id: int | jax.Array | None | tuple[int | jax.Array, ...] = None,
+    device_id_type: DeviceIdType = DeviceIdType.MESH,
+):
   inc = jnp.asarray(inc, dtype=jnp.int32)
   args = [sem, inc]
   has_device_id = device_id is not None
diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py
new file mode 100644
index 000000000000..828f8d7c139d
--- /dev/null
+++ b/jax/experimental/pallas/ops/tpu/all_gather.py
@@ -0,0 +1,204 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple all-gather kernel.
+
+This is meant to be a pedagogical example of how to write a custom collective
+using Pallas. It doesn't have all possible performance optimizations and doesn't
+currently handle more diverse topologies.
+
+The kernel assumes a ring structure on a single mesh axis. It takes the local
+chunk, splits it in two, and sends each of the half-chunks in each direction
+(left and right) until every device has received the half chunks.
+"""
+from __future__ import annotations
+import functools
+
+from typing import Sequence
+
+import jax
+from jax import lax
+from jax.experimental import pallas as pl
+from jax.experimental import shard_map
+from jax.experimental.pallas import tpu as pltpu
+import jax.numpy as jnp
+
+
+P = jax.sharding.PartitionSpec
+
+
+def get_neighbor(
+    idx: jax.Array, mesh: jax.sharding.Mesh, axis_name: str, *, direction: str
+) -> tuple[jax.Array, ...]:
+  """Computes the mesh indices of a ring neighbor along `axis_name`.
+
+  Args:
+    idx: this device's index along `axis_name`.
+    mesh: the mesh whose axis names define the layout of the result.
+    axis_name: the mesh axis along which to step to the neighbor.
+    direction: "right" for index + 1, "left" for index - 1; both wrap around.
+
+  Returns:
+    One index per mesh axis: the neighbor's index along `axis_name` and this
+    device's own `lax.axis_index` along every other axis.
+
+  Raises:
+    ValueError: if `direction` is neither "left" nor "right".
+  """
+  if direction not in ("left", "right"):
+    raise ValueError(f"Unknown direction: {direction!r}.")
+  axis_names = mesh.axis_names
+  which_axis = axis_names.index(axis_name)
+  mesh_index = [
+      idx if i == which_axis else lax.axis_index(a)
+      for i, a in enumerate(axis_names)
+  ]
+  axis_size = lax.psum(1, axis_name)
+  if direction == "right":
+    next_idx = lax.rem(idx + 1, axis_size)
+  else:
+    left = idx - 1
+    next_idx = jnp.where(left < 0, left + axis_size, left)
+  mesh_index[which_axis] = next_idx
+  return tuple(mesh_index)
+
+
+def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
+              mesh: jax.sharding.Mesh):
+  """Ring all-gather kernel body: gathers every device's shard into `o_ref`.
+
+  Args:
+    x_ref: ref holding the local shard.
+    o_ref: output ref; its leading dimension is indexed by device slot.
+    send_sem: two DMA semaphores; [0] is used by the leftward copies and [1]
+      by the rightward copies.
+    recv_sem: two DMA semaphores; [0] is used by the initial local copy and
+      the leftward copies, [1] by the rightward copies.
+    axis_name: the mesh axis to gather over.
+    mesh: the mesh, used to compute the neighbors' device ids.
+  """
+  my_id = lax.axis_index(axis_name)
+  # TODO(sharadmv): could speed this up having the first remote DMA go from
+  # x_ref->o_ref immediately instead of a blocking HBM copy.
+  with pltpu.trace("initial_copy"):
+    pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait()
+
+  with pltpu.trace("neighbour_lookup"):
+    axis_size = lax.psum(1, axis_name)
+    left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left")
+    right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right")
+
+  with pltpu.trace("main_barrier"):
+    sem = pltpu.get_barrier_semaphore()
+    # Each neighbor contributes 1, so the wait for 2 only completes once both
+    # neighbors have signaled and leaves the barrier semaphore at zero.
+    pltpu.semaphore_signal(sem, 1, device_id=left_neighbor)
+    pltpu.semaphore_signal(sem, 1, device_id=right_neighbor)
+    pltpu.semaphore_wait(sem, 2)
+
+  shard_size = x_ref.shape[0]
+  # The two slices partition the shard's rows; when `shard_size` is odd the
+  # right-going slice takes the extra row so that no row is left ungathered.
+  half = shard_size // 2
+  left_slice = pl.ds(0, half)
+  right_slice = pl.ds(half, shard_size - half)
+  right_dma, left_dma = None, None
+  # Main strategy for this AG: carve up our input into two slices. Send
+  # each slice along each direction until they reach every device.
+  for i in range(axis_size - 1):
+    right_slot = my_id - i
+    slot = jnp.where(right_slot < 0, axis_size + right_slot, right_slot)
+    if right_dma is not None:
+      with pltpu.trace("wait_right_dma"):
+        right_dma.wait()
+    right_dma = pltpu.async_remote_copy(
+        o_ref.at[slot, right_slice],
+        o_ref.at[slot, right_slice],
+        send_sem[1],
+        recv_sem[1],
+        device_id=right_neighbor,
+    )
+
+    left_slot = my_id + i
+    slot = lax.rem(left_slot, axis_size)
+    if left_dma is not None:
+      with pltpu.trace("wait_left_dma"):
+        left_dma.wait()
+    left_dma = pltpu.async_remote_copy(
+        o_ref.at[slot, left_slice],
+        o_ref.at[slot, left_slice],
+        send_sem[0],
+        recv_sem[0],
+        device_id=left_neighbor,
+    )
+  with pltpu.trace("wait_all_dma"):
+    assert right_dma is not None
+    assert left_dma is not None
+    right_dma.wait()
+    left_dma.wait()
+
+
+@functools.partial(
+    jax.jit, static_argnames=["mesh", "axis_name", "memory_space"]
+)
+def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
+               memory_space: pltpu.TPUMemorySpace = pltpu.VMEM):
+  """All-gathers `x` along its leading dimension over one mesh axis.
+
+  Args:
+    x: the array to gather; it is sharded along its leading dimension over
+      `axis_name` before the kernel runs.
+    mesh: the device mesh.
+    axis_name: the mesh axis to gather over, or a length-1 sequence of it.
+    memory_space: memory space used for the kernel's input and output blocks.
+
+  Returns:
+    `x` itself if the axis has size 1; otherwise the gathered array, whose
+    leading dimension is the axis size times the shard's leading dimension.
+
+  Raises:
+    NotImplementedError: if more than one axis name is given.
+  """
+  if isinstance(axis_name, str):
+    axis_name = (axis_name,)
+  # TODO(sharadmv): enable all gather over multiple axes
+  if len(axis_name) > 1:
+    raise NotImplementedError("Only one axis supported.")
+  axis_name, = axis_name
+  if mesh.shape[axis_name] == 1:
+    # We can short-circuit here if our axis size is 1
+    return x
+  def ag_local(x_shard):
+    axis_size = lax.psum(1, axis_name)
+    out_shape = jax.ShapeDtypeStruct((axis_size, *x_shard.shape), x_shard.dtype)
+    out = pl.pallas_call(
+        functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
+        out_shape=out_shape,
+        mosaic_params=dict(collective_id=0),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            scratch_shapes=(
+                (pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
+                (pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
+            ),
+            in_specs=[pl.BlockSpec(memory_space=memory_space)],
+            out_specs=pl.BlockSpec(memory_space=memory_space),
+        ),
+    )(x_shard)
+    return out.reshape((axis_size * x_shard.shape[0], *x_shard.shape[1:]))
+
+  return shard_map.shard_map(
+      ag_local, mesh=mesh, in_specs=P(axis_name), out_specs=P(None),
+      check_rep=False
+  )(x)
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index 0a10099ecebd..368689f0a159 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -64,3 +64,17 @@ py_test(
         "//third_party/py/jax:pallas",
     ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
 )
+
+jax_test(
+    name = "all_gather_test",
+    srcs = [
+        "all_gather_test.py",
+    ],
+    disable_backends = [
+        "cpu",
+        "gpu",
+    ],
+    deps = [
+        "//third_party/py/jax:pallas_tpu_ops",
+    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
+)
diff --git a/tests/pallas/all_gather_test.py b/tests/pallas/all_gather_test.py
new file mode 100644
index 000000000000..ae34d92a3ca4
--- /dev/null
+++ b/tests/pallas/all_gather_test.py
@@ -0,0 +1,125 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests the simple all_gather kernel."""
+from __future__ import annotations
+
+from absl.testing import absltest
+import hypothesis as hp
+import hypothesis.strategies as hps
+import jax
+from jax import random
+from jax._src import test_util as jtu
+from jax.experimental import mesh_utils
+from jax.experimental.pallas import tpu as pltpu
+from jax.experimental.pallas.ops.tpu import all_gather
+import jax.numpy as jnp
+import numpy as np
+
+
+jax.config.parse_flags_with_absl()
+
+P = jax.sharding.PartitionSpec
+hp.settings.register_profile(
+    "deterministic",
+    database=None,
+    derandomize=True,
+    deadline=None,
+    max_examples=50,
+    print_blob=True,
+    verbosity=hp.Verbosity.verbose,
+)
+hp.settings.load_profile("deterministic")
+
+
+@hps.composite
+def _array_shapes(draw):
+  """Draws one of the array shapes the kernel is currently tested on."""
+  # TODO(sharadmv, apaszke): enable this on a wider variety of shapes
+  valid_shapes = [
+      (128, 128),
+      (256, 128),
+      (256, 512),
+      (256, 1024),
+      # TODO(sharadmv,apaszke): enable these shapes
+      # (256, 129),
+      # (129, 128),
+      # (64, 64),
+      # (1, 1),
+  ]
+  return draw(hps.sampled_from(valid_shapes))
+
+
+@hps.composite
+def _array_dtypes(draw):
+  """Draws one of the dtypes the kernel is currently tested on."""
+  return draw(
+      hps.sampled_from([
+          jnp.float32,
+          jnp.bfloat16,
+          jnp.int32,
+          # jnp.float16,  # TODO(sharadmv,apaszke): enable float16 all gather
+          # jnp.int16,  # TODO(sharadmv,apaszke): enable int16 all gather
+          # jnp.int8,  # TODO(sharadmv,apaszke): enable int8 all gather
+      ])
+  )
+
+
+class AllGatherTest(jtu.JaxTestCase):
+  """Checks the Pallas all_gather kernel against the unsharded input."""
+
+  @hp.given(hps.booleans(), _array_shapes(), _array_dtypes())
+  def test_all_gather_1d_mesh(self, is_vmem, shape, dtype):
+    """Gathers over the only axis of a 1D mesh spanning all devices."""
+    if jax.device_count() < 2:
+      self.skipTest("Need more devices")
+    memory_space = pltpu.VMEM if is_vmem else pltpu.ANY
+    mesh_shape = (jax.device_count(),)
+    mesh = jax.sharding.Mesh(
+        mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"]
+    )
+    leading, *rest = shape
+    shape = (mesh.shape["x"] * leading, *rest)
+    x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype)
+    x_sharded = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x")))
+    y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name="x",
+                              memory_space=memory_space)
+    np.testing.assert_array_equal(y, x)
+
+  @hp.given(hps.booleans(), _array_shapes(), _array_dtypes(),
+            hps.sampled_from(["x", "y"]))
+  def test_all_gather_2d_mesh(self, is_vmem, shape, dtype,
+                              axis_name):
+    """Gathers over either axis of a (2, n // 2) mesh of all n devices."""
+    if jax.device_count() < 2:
+      self.skipTest("Need more devices")
+    if jax.device_count() % 2:
+      self.skipTest("Need an even number of devices")
+    memory_space = pltpu.VMEM if is_vmem else pltpu.ANY
+    mesh_shape = (2, jax.device_count() // 2)
+    mesh = jax.sharding.Mesh(
+        mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"]
+    )
+    sharding = jax.sharding.NamedSharding(mesh, P(axis_name, None))
+    leading, *rest = shape
+    shape = (mesh.shape[axis_name] * leading, *rest)
+    x = random.normal(random.key(0), shape, dtype=jnp.float32).astype(dtype)
+    x_sharded = jax.device_put(x, sharding)
+    y = all_gather.all_gather(x_sharded, mesh=mesh, axis_name=axis_name,
+                              memory_space=memory_space)
+    np.testing.assert_array_equal(y, x)
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())