From b801539f5c9a3857ce0d274f8ca61f5c4259b5ee Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Thu, 28 Nov 2024 08:34:19 -0800 Subject: [PATCH] [Pallas][Mosaic GPU] Add support for compressing squeezed dims in async_copy + grid fixes This change removes the need to flatten the batch dimension into sequence dimensions in the flash attention kernel. The critical thing here is the observation that we can in fact collapse all squeezed dimension into a single one in the TMA descriptor, letting us reduce its rank when necessary. Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU lowering, which I've fixed. PiperOrigin-RevId: 701035277 --- jax/_src/pallas/mosaic_gpu/lowering.py | 29 +++--- jax/experimental/mosaic/gpu/core.py | 89 +++++++++++++++++-- .../pallas/ops/gpu/attention_mgpu.py | 34 +++---- tests/mosaic/gpu_test.py | 2 +- 4 files changed, 110 insertions(+), 44 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6e7adfc60a53..87dfe2ce776e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -360,10 +360,6 @@ def lower_jaxpr_to_module( assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims - if len(grid_mapping.grid) > 3: - raise NotImplementedError( - "Only <=3D grids are supported in Mosaic GPU lowering." - ) if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "Dynamic grid bounds not supported in the Mosaic GPU lowering." @@ -397,16 +393,19 @@ def lower_jaxpr_to_module( f" {max_concurrent_steps=}, {delay_release=}" ) - block = (128, 1, 1) - grid = grid_mapping.grid if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid[:-1] - - grid = [d for i, d in enumerate(grid) if i not in sequential_axes] - if len(grid) < 3: - grid += (1,) * (3 - len(grid)) + logical_grid = grid_mapping.grid[:-1] else: + block = (128, 1, 1) + logical_grid = grid_mapping.grid + + parallel_grid = [ + d for i, d in enumerate(logical_grid) if i not in sequential_axes + ] + if len(parallel_grid) < 3: + parallel_grid += (1,) * (3 - len(parallel_grid)) + elif len(parallel_grid) > 3: raise NotImplementedError( "Only <=3D grids are supported in Mosaic GPU lowering." ) @@ -500,7 +499,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): _program_id(next(parallel_count)) if axis not in sequential_axes else None - for axis in range(len(grid_mapping.grid)) + for axis in range(len(logical_grid)) ] def make_program_ids(step: ir.Value): @@ -788,7 +787,7 @@ def _(step, carry): prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel( body, - grid=grid, + grid=parallel_grid, cluster=(), block=block, in_shapes=in_structs_gmem, @@ -806,7 +805,9 @@ def _(step, carry): prof_spec=prof_spec, ) - return LoweringResult(module, grid, block, out_structs_gmem, prof_ctx) + return LoweringResult( + module, parallel_grid, block, out_structs_gmem, prof_ctx + ) mosaic_lowering_rules = {} diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index d9894051dc81..16b7f1f59c33 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -234,6 +234,57 @@ def batch(self, leading_rank: int) -> MemRefTransform: ) +@dataclasses.dataclass(frozen=True) +class CollapseLeadingIndicesTransform(MemRefTransform): + """Collapses leading indices into one.""" + strides: tuple[int, ...] + + @functools.cached_property + def common_stride(self) -> int: + return math.gcd(*self.strides) + + def apply(self, ref: ir.Value) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + strides, offset = ref_ty.get_strides_and_offset() + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + raise NotImplementedError("Dynamic offsets are not supported") + max_bound = sum( + (d - 1) * s // self.common_stride + for d, s in zip( + ref_ty.shape[: len(self.strides)], strides[: len(self.strides)] + ) + ) + 1 + new_shape = [max_bound, *ref_ty.shape[len(self.strides):]] + new_strides = [self.common_stride, *strides[len(self.strides):]] + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ref_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + return memref.reinterpret_cast( + new_ref_ty, ref, [], [], [], + static_offsets=[offset], + static_sizes=new_shape, + static_strides=new_strides, + ) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + flat_idx = c(0, index) + for i, s in zip(idx[:len(self.strides)], self.strides): + flat_idx = arith.addi( + flat_idx, arith.muli(i, c(s // self.common_stride, index)) + ) + return (flat_idx, *idx[len(self.strides):]) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + if any(s != 1 for s in shape[:len(self.strides)]): + raise ValueError("Expected leading indices to be squeezed") + return (1, *shape[len(self.strides):]) + + def batch(self, leading_rank: int) -> MemRefTransform: + raise NotImplementedError # Unused + + OnDeviceProfiler = profiler.OnDeviceProfiler @@ -397,6 +448,17 @@ def async_copy( or gmem_ref.owner.opview.OPERATION_NAME != expected_name ): raise ValueError("GMEM reference in async_copy must be a kernel argument") + gmem_ref_ty = ir.MemRefType(gmem_ref.type) + gmem_strides, _ = gmem_ref_ty.get_strides_and_offset() + if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape): + raise NotImplementedError( + "async_copy assumes the GMEM reference is contiguous" + ) + if any(s * element_bytewidth % 16 != 0 for s in gmem_strides[:-1]): + raise ValueError( + "async_copy requires all GMEM strides except the last one to be a" + " multiple of 16 bytes" + ) base_indices, slice_shape, is_squeezed = utils.parse_indices( gmem_slice, ir.MemRefType(gmem_ref.type).shape @@ -421,9 +483,25 @@ def async_copy( dyn_base_indices = t.transform_index(dyn_base_indices) slice_shape = t.transform_shape(slice_shape) + num_squeezed_dims = len(squeezed_dims) + if len(slice_shape) > 5: + # We can try to collapse all squeezed dims into one. + if len(slice_shape) - num_squeezed_dims + 1 > 5: + raise ValueError( + "Async copies only support striding up to 5 dimensions" + ) + collapse = CollapseLeadingIndicesTransform( + tuple(gmem_strides[d] for d in squeezed_dims) + ) + gmem_transform = (*gmem_transform, collapse) + dyn_base_indices = collapse.transform_index(dyn_base_indices) + slice_shape = collapse.transform_shape(slice_shape) + num_squeezed_dims = 1 + del squeezed_dims, sliced_dims # Those no longer make sense. + smem_ref_ty = ir.MemRefType(smem_ref.type) # We moved all squeezed dims to the front. - if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape): + if slice_shape[num_squeezed_dims:] != tuple(smem_ref_ty.shape): raise ValueError( "Expected the SMEM reference to have the same shape as the" f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" @@ -437,7 +515,7 @@ def async_copy( dyn_base_indices = list(dyn_base_indices) slice_shape = list(slice_shape) - assert all(d == 1 for d in slice_shape[:len(squeezed_dims)]) + assert all(d == 1 for d in slice_shape[:num_squeezed_dims]) collective_size = 1 if collective is not None: if isinstance(collective, gpu.Dimension): @@ -446,14 +524,14 @@ def async_copy( if collective_size > 1: def partition_dim(dim: int, idx: ir.Value, num_chunks: int): # No need to partition squeezed dims. They don't even exist in smem_ref. - assert dim >= len(squeezed_dims) + assert dim >= num_squeezed_dims nonlocal smem_ref slice_shape[dim] //= num_chunks block_offset = arith.muli(idx, c(slice_shape[dim], index)) dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) smem_ref = utils.memref_slice( smem_ref, - (slice(None),) * (dim - len(squeezed_dims)) + (slice(None),) * (dim - num_squeezed_dims) + (utils.ds(block_offset, slice_shape[dim]),), ) stride = 1 @@ -508,9 +586,6 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): else contextlib.nullcontext ) - rank = len(slice_shape) - if rank > 5: # TODO: apaszke - Implement stride compression - raise ValueError("Async copies only support striding up to 5 dimensions") if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index c4ac7e625942..78db197c673d 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -61,25 +61,14 @@ def attention(q, k, v, config: TuningConfig): raise ValueError(f"{head_dim=} must be divisible by 64") if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") - # Squash batch and sequence dimensions. - # This is required because CUDA grid/TMA descriptors have a limited number of - # slice dimensions. - # TODO(apaszke): Implement slice squashing for TMAs. - q = jnp.reshape(q, (batch_size * q_seq_len, num_q_heads, head_dim)) - k = jnp.reshape(k, (batch_size * kv_seq_len, num_kv_heads, head_dim)) - v = jnp.reshape(v, (batch_size * kv_seq_len, num_kv_heads, head_dim)) max_concurrent_steps = min( config.max_concurrent_steps, kv_seq_len // config.block_kv ) block_q, block_kv = config.block_q, config.block_kv - num_q_tiles, rem = divmod(q_seq_len, block_q * 2) - if rem: - raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") def kernel(q_ref, k_ref, v_ref, out_ref, scoped): - bidx = lax.div(lax.axis_index("bq"), num_q_tiles) - qidx = lax.rem(lax.axis_index("bq"), num_q_tiles) + batch = lax.axis_index("batch") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") qo_smem2, k_smem, v_smem = smem_buffers @@ -93,11 +82,11 @@ def perform_schedule_barrier(): def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] - q_seq_base = qidx * (2 * block_q) + wg_idx * block_q + bidx * q_seq_len + q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") plgpu.copy_gmem_to_smem( - q_ref.at[pl.ds(q_seq_base, block_q), q_head], + q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], qo_smem, q_barriers.at[wg_idx], ) @@ -167,7 +156,7 @@ def _wait(): qo_smem[...] = acc.astype(dtype) plgpu.commit_smem() plgpu.copy_smem_to_gmem( - qo_smem, out_ref.at[pl.ds(q_seq_base, block_q), q_head], + qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) @@ -175,16 +164,14 @@ def _memory_wg(): plgpu.set_max_registers(40, action="decrease") kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) for i in range(max_concurrent_steps): - start = i * block_kv + bidx * kv_seq_len - s = (pl.ds(start, block_kv), kv_head) + s = (batch, pl.ds(i * block_kv, block_kv), kv_head) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) - start = tma_step * block_kv + bidx * kv_seq_len - s = (pl.ds(start, block_kv), kv_head) + s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barrier) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barrier) @@ -199,10 +186,13 @@ def kv_epilogue(i, _): def run(refs): q_ref, k_ref, v_ref, out_ref = refs + num_q_tiles, rem = divmod(q_seq_len, block_q * 2) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") mesh = plgpu.GPUMesh( - grid=(batch_size * num_q_tiles, num_q_heads), + grid=(batch_size, num_q_tiles, num_q_heads), num_threads=3, - axis_names=("bq", "heads", "wg"), + axis_names=("batch", "q_seq", "heads", "wg"), approx_math=True, ) @pl.core_map(mesh) @@ -236,7 +226,7 @@ def _kernel_entry(): ) _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) - return jnp.reshape(out, [batch_size, q_seq_len, num_q_heads, head_dim]) + return out @jax.jit diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 71f2d383f809..39182841f190 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1240,7 +1240,7 @@ def run_kernel(shape): x = np.arange(np.prod(shape)).reshape(shape) _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) - with self.assertRaisesRegex(ValueError, "only support striding up to 5"): + with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"): run_kernel([1] * 6) with self.assertRaisesRegex(