diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6e7adfc60a53..87dfe2ce776e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -360,10 +360,6 @@ def lower_jaxpr_to_module( assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims - if len(grid_mapping.grid) > 3: - raise NotImplementedError( - "Only <=3D grids are supported in Mosaic GPU lowering." - ) if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "Dynamic grid bounds not supported in the Mosaic GPU lowering." @@ -397,16 +393,19 @@ def lower_jaxpr_to_module( f" {max_concurrent_steps=}, {delay_release=}" ) - block = (128, 1, 1) - grid = grid_mapping.grid if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count block = (128 * grid_mapping.grid[-1], 1, 1) - grid = grid[:-1] - - grid = [d for i, d in enumerate(grid) if i not in sequential_axes] - if len(grid) < 3: - grid += (1,) * (3 - len(grid)) + logical_grid = grid_mapping.grid[:-1] else: + block = (128, 1, 1) + logical_grid = grid_mapping.grid + + parallel_grid = [ + d for i, d in enumerate(logical_grid) if i not in sequential_axes + ] + if len(parallel_grid) < 3: + parallel_grid += (1,) * (3 - len(parallel_grid)) + elif len(parallel_grid) > 3: raise NotImplementedError( "Only <=3D grids are supported in Mosaic GPU lowering." ) @@ -500,7 +499,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): _program_id(next(parallel_count)) if axis not in sequential_axes else None - for axis in range(len(grid_mapping.grid)) + for axis in range(len(logical_grid)) ] def make_program_ids(step: ir.Value): @@ -788,7 +787,7 @@ def _(step, carry): prof_ctx = ProfilerContext(params["profile_dir"], prof_spec) module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel( body, - grid=grid, + grid=parallel_grid, cluster=(), block=block, in_shapes=in_structs_gmem, @@ -806,7 +805,9 @@ def _(step, carry): prof_spec=prof_spec, ) - return LoweringResult(module, grid, block, out_structs_gmem, prof_ctx) + return LoweringResult( + module, parallel_grid, block, out_structs_gmem, prof_ctx + ) mosaic_lowering_rules = {} diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index d9894051dc81..16b7f1f59c33 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -234,6 +234,57 @@ def batch(self, leading_rank: int) -> MemRefTransform: ) +@dataclasses.dataclass(frozen=True) +class CollapseLeadingIndicesTransform(MemRefTransform): + """Collapses leading indices into one.""" + strides: tuple[int, ...] 
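+  # GMEM strides (in elements) of the leading dimensions being collapsed.
+  # Indices along them are flattened into a single index whose unit is their
+  # GCD, so the collapsed dimension gets stride `common_stride`.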
+ + @functools.cached_property + def common_stride(self) -> int: + return math.gcd(*self.strides) + + def apply(self, ref: ir.Value) -> ir.Value: + ref_ty = ir.MemRefType(ref.type) + strides, offset = ref_ty.get_strides_and_offset() + if offset == ir.ShapedType.get_dynamic_stride_or_offset(): + raise NotImplementedError("Dynamic offsets are not supported") + max_bound = sum( + (d - 1) * s // self.common_stride + for d, s in zip( + ref_ty.shape[: len(self.strides)], strides[: len(self.strides)] + ) + ) + 1 + new_shape = [max_bound, *ref_ty.shape[len(self.strides):]] + new_strides = [self.common_stride, *strides[len(self.strides):]] + new_layout = ir.StridedLayoutAttr.get(offset, new_strides) + new_ref_ty = ir.MemRefType.get( + new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space + ) + return memref.reinterpret_cast( + new_ref_ty, ref, [], [], [], + static_offsets=[offset], + static_sizes=new_shape, + static_strides=new_strides, + ) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + flat_idx = c(0, index) + for i, s in zip(idx[:len(self.strides)], self.strides): + flat_idx = arith.addi( + flat_idx, arith.muli(i, c(s // self.common_stride, index)) + ) + return (flat_idx, *idx[len(self.strides):]) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + if any(s != 1 for s in shape[:len(self.strides)]): + raise ValueError("Expected leading indices to be squeezed") + return (1, *shape[len(self.strides):]) + + def batch(self, leading_rank: int) -> MemRefTransform: + raise NotImplementedError # Unused + + OnDeviceProfiler = profiler.OnDeviceProfiler @@ -397,6 +448,17 @@ def async_copy( or gmem_ref.owner.opview.OPERATION_NAME != expected_name ): raise ValueError("GMEM reference in async_copy must be a kernel argument") + gmem_ref_ty = ir.MemRefType(gmem_ref.type) + gmem_strides, _ = gmem_ref_ty.get_strides_and_offset() + if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape): + raise NotImplementedError( + "async_copy assumes the GMEM reference is contiguous" + ) + if any(s * element_bytewidth % 16 != 0 for s in gmem_strides[:-1]): + raise ValueError( + "async_copy requires all GMEM strides except the last one to be a" + " multiple of 16 bytes" + ) base_indices, slice_shape, is_squeezed = utils.parse_indices( gmem_slice, ir.MemRefType(gmem_ref.type).shape @@ -421,9 +483,25 @@ def async_copy( dyn_base_indices = t.transform_index(dyn_base_indices) slice_shape = t.transform_shape(slice_shape) + num_squeezed_dims = len(squeezed_dims) + if len(slice_shape) > 5: + # We can try to collapse all squeezed dims into one. + if len(slice_shape) - num_squeezed_dims + 1 > 5: + raise ValueError( + "Async copies only support striding up to 5 dimensions" + ) + collapse = CollapseLeadingIndicesTransform( + tuple(gmem_strides[d] for d in squeezed_dims) + ) + gmem_transform = (*gmem_transform, collapse) + dyn_base_indices = collapse.transform_index(dyn_base_indices) + slice_shape = collapse.transform_shape(slice_shape) + num_squeezed_dims = 1 + del squeezed_dims, sliced_dims # Those no longer make sense. + smem_ref_ty = ir.MemRefType(smem_ref.type) # We moved all squeezed dims to the front. 
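+    # num_squeezed_dims is 1 if the squeezed dims were collapsed above,
+    # otherwise it equals len(squeezed_dims).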
- if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape): + if slice_shape[num_squeezed_dims:] != tuple(smem_ref_ty.shape): raise ValueError( "Expected the SMEM reference to have the same shape as the" f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" @@ -437,7 +515,7 @@ def async_copy( dyn_base_indices = list(dyn_base_indices) slice_shape = list(slice_shape) - assert all(d == 1 for d in slice_shape[:len(squeezed_dims)]) + assert all(d == 1 for d in slice_shape[:num_squeezed_dims]) collective_size = 1 if collective is not None: if isinstance(collective, gpu.Dimension): @@ -446,14 +524,14 @@ def async_copy( if collective_size > 1: def partition_dim(dim: int, idx: ir.Value, num_chunks: int): # No need to partition squeezed dims. They don't even exist in smem_ref. - assert dim >= len(squeezed_dims) + assert dim >= num_squeezed_dims nonlocal smem_ref slice_shape[dim] //= num_chunks block_offset = arith.muli(idx, c(slice_shape[dim], index)) dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) smem_ref = utils.memref_slice( smem_ref, - (slice(None),) * (dim - len(squeezed_dims)) + (slice(None),) * (dim - num_squeezed_dims) + (utils.ds(block_offset, slice_shape[dim]),), ) stride = 1 @@ -508,9 +586,6 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): else contextlib.nullcontext ) - rank = len(slice_shape) - if rank > 5: # TODO: apaszke - Implement stride compression - raise ValueError("Async copies only support striding up to 5 dimensions") if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py index c4ac7e625942..78db197c673d 100644 --- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py +++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py @@ -61,25 +61,14 @@ def attention(q, k, v, config: TuningConfig): raise ValueError(f"{head_dim=} must be divisible by 64") if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]): raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}") - # Squash batch and sequence dimensions. - # This is required because CUDA grid/TMA descriptors have a limited number of - # slice dimensions. - # TODO(apaszke): Implement slice squashing for TMAs. 
- q = jnp.reshape(q, (batch_size * q_seq_len, num_q_heads, head_dim)) - k = jnp.reshape(k, (batch_size * kv_seq_len, num_kv_heads, head_dim)) - v = jnp.reshape(v, (batch_size * kv_seq_len, num_kv_heads, head_dim)) max_concurrent_steps = min( config.max_concurrent_steps, kv_seq_len // config.block_kv ) block_q, block_kv = config.block_q, config.block_kv - num_q_tiles, rem = divmod(q_seq_len, block_q * 2) - if rem: - raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") def kernel(q_ref, k_ref, v_ref, out_ref, scoped): - bidx = lax.div(lax.axis_index("bq"), num_q_tiles) - qidx = lax.rem(lax.axis_index("bq"), num_q_tiles) + batch = lax.axis_index("batch") smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped wg_idx = lax.axis_index("wg") qo_smem2, k_smem, v_smem = smem_buffers @@ -93,11 +82,11 @@ def perform_schedule_barrier(): def _compute_wg(): plgpu.set_max_registers(232, action="increase") qo_smem = qo_smem2.at[wg_idx] - q_seq_base = qidx * (2 * block_q) + wg_idx * block_q + bidx * q_seq_len + q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q q_head = lax.axis_index("heads") plgpu.copy_gmem_to_smem( - q_ref.at[pl.ds(q_seq_base, block_q), q_head], + q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], qo_smem, q_barriers.at[wg_idx], ) @@ -167,7 +156,7 @@ def _wait(): qo_smem[...] = acc.astype(dtype) plgpu.commit_smem() plgpu.copy_smem_to_gmem( - qo_smem, out_ref.at[pl.ds(q_seq_base, block_q), q_head], + qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head], ) plgpu.wait_smem_to_gmem(0) @pl.when(wg_idx == 2) @@ -175,16 +164,14 @@ def _memory_wg(): plgpu.set_max_registers(40, action="decrease") kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head) for i in range(max_concurrent_steps): - start = i * block_kv + bidx * kv_seq_len - s = (pl.ds(start, block_kv), kv_head) + s = (batch, pl.ds(i * block_kv, block_kv), kv_head) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i]) plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i]) def kv_loop(kv_step, _): tma_step = kv_step + max_concurrent_steps tma_slot = lax.rem(kv_step, max_concurrent_steps) - start = tma_step * block_kv + bidx * kv_seq_len - s = (pl.ds(start, block_kv), kv_head) + s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head) plgpu.barrier_wait(k_consumed_barrier) plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot]) plgpu.barrier_wait(v_consumed_barrier) @@ -199,10 +186,13 @@ def kv_epilogue(i, _): def run(refs): q_ref, k_ref, v_ref, out_ref = refs + num_q_tiles, rem = divmod(q_seq_len, block_q * 2) + if rem: + raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}") mesh = plgpu.GPUMesh( - grid=(batch_size * num_q_tiles, num_q_heads), + grid=(batch_size, num_q_tiles, num_q_heads), num_threads=3, - axis_names=("bq", "heads", "wg"), + axis_names=("batch", "q_seq", "heads", "wg"), approx_math=True, ) @pl.core_map(mesh) @@ -236,7 +226,7 @@ def _kernel_entry(): ) _, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf))) - return jnp.reshape(out, [batch_size, q_seq_len, num_q_heads, head_dim]) + return out @jax.jit diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 71f2d383f809..39182841f190 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1240,7 +1240,7 @@ def run_kernel(shape): x = np.arange(np.prod(shape)).reshape(shape) _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) - with 
self.assertRaisesRegex(ValueError, "only support striding up to 5"): + with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"): run_kernel([1] * 6) with self.assertRaisesRegex(
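
A minimal plain-Python sketch of the index-collapsing arithmetic implemented by CollapseLeadingIndicesTransform above; the helper name `collapse` and the concrete shapes are invented for illustration and are not part of the patch.

import math

def collapse(strides, shape, idx):
  """Fold leading dims with the given GMEM strides (in elements) into one.

  Mirrors the arithmetic of CollapseLeadingIndicesTransform: the collapsed
  dimension uses gcd(strides) as its stride, so every original offset stays
  expressible as flat_idx * common_stride.
  """
  g = math.gcd(*strides)
  # Size of the collapsed dim: largest reachable flattened index, plus one.
  size = sum((d - 1) * s // g for d, s in zip(shape, strides)) + 1
  # A tuple of leading indices maps to one index measured in units of `g`.
  flat_idx = sum(i * (s // g) for i, s in zip(idx, strides))
  return g, size, flat_idx

# Collapsing the two leading dims of a contiguous (4, 3, 128) array, whose
# element strides are (384, 128, 1), at leading index (2, 1):
g, size, flat = collapse(strides=(384, 128), shape=(4, 3), idx=(2, 1))
assert g == 128                         # gcd(384, 128)
assert size == 12                       # all 4 * 3 leading positions fit
assert flat == 7                        # 2 * 3 + 1
assert flat * g == 2 * 384 + 1 * 128    # same element offset as before

Using the GCD of the strides as the unit is what lets `apply` reinterpret the squeezed leading dimensions as a single strided dimension without losing any addressable offset.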