[Pallas][Mosaic GPU] Add support for compressing squeezed dims in async_copy + grid fixes

This change removes the need to flatten the batch dimension into the sequence
dimensions in the flash attention kernel. The key observation is that we can
collapse all squeezed dimensions into a single one in the TMA descriptor,
letting us reduce its rank when necessary.

Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU
lowering, which I've fixed.

PiperOrigin-RevId: 701035277
apaszke authored and Google-ML-Automation committed Nov 28, 2024
1 parent d5bfafb commit b801539
Showing 4 changed files with 110 additions and 44 deletions.
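
The collapse amounts to replacing all leading squeezed (size-1 after slicing) dimensions with a single flat dimension whose stride is the GCD of their strides. Below is a minimal standalone Python sketch of that arithmetic; it mirrors the new CollapseLeadingIndicesTransform further down, but the helper name and the example shape are invented for illustration.

import math

def collapse_leading(shape, strides, idx, num_squeezed):
  # Fold the first `num_squeezed` dims (all squeezed to size 1 by the slice)
  # into one flat dim whose stride is the GCD of their strides.
  lead_strides = strides[:num_squeezed]
  common = math.gcd(*lead_strides)
  # Largest flat index reachable by the leading dims, in units of `common`.
  max_bound = sum(
      (d - 1) * s // common
      for d, s in zip(shape[:num_squeezed], lead_strides)
  ) + 1
  flat_idx = sum(
      i * (s // common) for i, s in zip(idx[:num_squeezed], lead_strides)
  )
  new_shape = (max_bound, *shape[num_squeezed:])
  new_strides = (common, *strides[num_squeezed:])
  new_idx = (flat_idx, *idx[num_squeezed:])
  return new_shape, new_strides, new_idx

# A contiguous (2, 3, 128, 64) ref indexed at [1, 2, :, :]: the two squeezed
# leading dims become one flat dim of size 6, reducing the descriptor rank.
print(collapse_leading(
    (2, 3, 128, 64), (3 * 128 * 64, 128 * 64, 64, 1), (1, 2, 0, 0), 2))
# -> ((6, 128, 64), (8192, 64, 1), (5, 0, 0))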
29 changes: 15 additions & 14 deletions jax/_src/pallas/mosaic_gpu/lowering.py
@@ -360,10 +360,6 @@ def lower_jaxpr_to_module(

assert len(jaxpr.outvars) == 0
assert not grid_mapping.vmapped_dims
if len(grid_mapping.grid) > 3:
raise NotImplementedError(
"Only <=3D grids are supported in Mosaic GPU lowering."
)
if grid_mapping.num_dynamic_grid_bounds:
raise NotImplementedError(
"Dynamic grid bounds not supported in the Mosaic GPU lowering."
@@ -397,16 +393,19 @@ def lower_jaxpr_to_module(
f" {max_concurrent_steps=}, {delay_release=}"
)

block = (128, 1, 1)
grid = grid_mapping.grid
if grid_mapping.grid_names: # Last dim corresponds to the warpgroup count
block = (128 * grid_mapping.grid[-1], 1, 1)
grid = grid[:-1]

grid = [d for i, d in enumerate(grid) if i not in sequential_axes]
if len(grid) < 3:
grid += (1,) * (3 - len(grid))
logical_grid = grid_mapping.grid[:-1]
else:
block = (128, 1, 1)
logical_grid = grid_mapping.grid

parallel_grid = [
d for i, d in enumerate(logical_grid) if i not in sequential_axes
]
if len(parallel_grid) < 3:
parallel_grid += (1,) * (3 - len(parallel_grid))
elif len(parallel_grid) > 3:
raise NotImplementedError(
"Only <=3D grids are supported in Mosaic GPU lowering."
)
@@ -500,7 +499,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
_program_id(next(parallel_count))
if axis not in sequential_axes
else None
for axis in range(len(grid_mapping.grid))
for axis in range(len(logical_grid))
]

def make_program_ids(step: ir.Value):
@@ -788,7 +787,7 @@ def _(step, carry):
prof_ctx = ProfilerContext(params["profile_dir"], prof_spec)
module, out_structs_gmem, _ = mgpu_core._lower_as_gpu_kernel(
body,
grid=grid,
grid=parallel_grid,
cluster=(),
block=block,
in_shapes=in_structs_gmem,
@@ -806,7 +805,9 @@ def _(step, carry):
prof_spec=prof_spec,
)

return LoweringResult(module, grid, block, out_structs_gmem, prof_ctx)
return LoweringResult(
module, parallel_grid, block, out_structs_gmem, prof_ctx
)


mosaic_lowering_rules = {}
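
The grid handling after this change boils down to the following standalone sketch (an illustration of the logic above, not the actual lowering code): sequential axes are stripped from the logical grid first, and only the remaining parallel axes must fit into CUDA's 3D launch grid.

def derive_parallel_grid(logical_grid, sequential_axes):
  parallel_grid = [
      d for i, d in enumerate(logical_grid) if i not in sequential_axes
  ]
  if len(parallel_grid) > 3:
    raise NotImplementedError(
        "Only <=3D grids are supported in Mosaic GPU lowering."
    )
  return tuple(parallel_grid) + (1,) * (3 - len(parallel_grid))

# A 4D logical grid with one sequential axis now lowers fine, whereas the old
# code rejected any logical grid with more than three dimensions up front.
assert derive_parallel_grid((2, 16, 8, 32), sequential_axes={3}) == (2, 16, 8)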
89 changes: 82 additions & 7 deletions jax/experimental/mosaic/gpu/core.py
@@ -234,6 +234,57 @@ def batch(self, leading_rank: int) -> MemRefTransform:
)


@dataclasses.dataclass(frozen=True)
class CollapseLeadingIndicesTransform(MemRefTransform):
"""Collapses leading indices into one."""
strides: tuple[int, ...]

@functools.cached_property
def common_stride(self) -> int:
return math.gcd(*self.strides)

def apply(self, ref: ir.Value) -> ir.Value:
ref_ty = ir.MemRefType(ref.type)
strides, offset = ref_ty.get_strides_and_offset()
if offset == ir.ShapedType.get_dynamic_stride_or_offset():
raise NotImplementedError("Dynamic offsets are not supported")
max_bound = sum(
(d - 1) * s // self.common_stride
for d, s in zip(
ref_ty.shape[: len(self.strides)], strides[: len(self.strides)]
)
) + 1
new_shape = [max_bound, *ref_ty.shape[len(self.strides):]]
new_strides = [self.common_stride, *strides[len(self.strides):]]
new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
new_ref_ty = ir.MemRefType.get(
new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
)
return memref.reinterpret_cast(
new_ref_ty, ref, [], [], [],
static_offsets=[offset],
static_sizes=new_shape,
static_strides=new_strides,
)

def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
index = ir.IndexType.get()
flat_idx = c(0, index)
for i, s in zip(idx[:len(self.strides)], self.strides):
flat_idx = arith.addi(
flat_idx, arith.muli(i, c(s // self.common_stride, index))
)
return (flat_idx, *idx[len(self.strides):])

def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
if any(s != 1 for s in shape[:len(self.strides)]):
raise ValueError("Expected leading indices to be squeezed")
return (1, *shape[len(self.strides):])

def batch(self, leading_rank: int) -> MemRefTransform:
raise NotImplementedError # Unused


OnDeviceProfiler = profiler.OnDeviceProfiler


@@ -397,6 +448,17 @@ def async_copy(
or gmem_ref.owner.opview.OPERATION_NAME != expected_name
):
raise ValueError("GMEM reference in async_copy must be a kernel argument")
gmem_ref_ty = ir.MemRefType(gmem_ref.type)
gmem_strides, _ = gmem_ref_ty.get_strides_and_offset()
if gmem_strides != utils.get_contiguous_strides(gmem_ref_ty.shape):
raise NotImplementedError(
"async_copy assumes the GMEM reference is contiguous"
)
if any(s * element_bytewidth % 16 != 0 for s in gmem_strides[:-1]):
raise ValueError(
"async_copy requires all GMEM strides except the last one to be a"
" multiple of 16 bytes"
)

base_indices, slice_shape, is_squeezed = utils.parse_indices(
gmem_slice, ir.MemRefType(gmem_ref.type).shape
@@ -421,9 +483,25 @@
dyn_base_indices = t.transform_index(dyn_base_indices)
slice_shape = t.transform_shape(slice_shape)

num_squeezed_dims = len(squeezed_dims)
if len(slice_shape) > 5:
# We can try to collapse all squeezed dims into one.
if len(slice_shape) - num_squeezed_dims + 1 > 5:
raise ValueError(
"Async copies only support striding up to 5 dimensions"
)
collapse = CollapseLeadingIndicesTransform(
tuple(gmem_strides[d] for d in squeezed_dims)
)
gmem_transform = (*gmem_transform, collapse)
dyn_base_indices = collapse.transform_index(dyn_base_indices)
slice_shape = collapse.transform_shape(slice_shape)
num_squeezed_dims = 1
del squeezed_dims, sliced_dims # Those no longer make sense.

smem_ref_ty = ir.MemRefType(smem_ref.type)
# We moved all squeezed dims to the front.
if slice_shape[len(squeezed_dims):] != tuple(smem_ref_ty.shape):
if slice_shape[num_squeezed_dims:] != tuple(smem_ref_ty.shape):
raise ValueError(
"Expected the SMEM reference to have the same shape as the"
f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
@@ -437,7 +515,7 @@

dyn_base_indices = list(dyn_base_indices)
slice_shape = list(slice_shape)
assert all(d == 1 for d in slice_shape[:len(squeezed_dims)])
assert all(d == 1 for d in slice_shape[:num_squeezed_dims])
collective_size = 1
if collective is not None:
if isinstance(collective, gpu.Dimension):
@@ -446,14 +524,14 @@
if collective_size > 1:
def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
# No need to partition squeezed dims. They don't even exist in smem_ref.
assert dim >= len(squeezed_dims)
assert dim >= num_squeezed_dims
nonlocal smem_ref
slice_shape[dim] //= num_chunks
block_offset = arith.muli(idx, c(slice_shape[dim], index))
dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset)
smem_ref = utils.memref_slice(
smem_ref,
(slice(None),) * (dim - len(squeezed_dims))
(slice(None),) * (dim - num_squeezed_dims)
+ (utils.ds(block_offset, slice_shape[dim]),),
)
stride = 1
@@ -508,9 +586,6 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
else contextlib.nullcontext
)

rank = len(slice_shape)
if rank > 5: # TODO: apaszke - Implement stride compression
raise ValueError("Async copies only support striding up to 5 dimensions")
if max(slice_shape) > 256:
raise ValueError(
"Async copies only support copying <=256 elements along each"
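
The new checks in async_copy assume a contiguous GMEM ref whose strides (except the innermost one) are multiples of 16 bytes. A small illustrative sketch for an f16 ref, assuming element_bytewidth == 2 (the helper below is a stand-in for utils.get_contiguous_strides, not part of the Mosaic GPU API):

def contiguous_strides(shape):
  strides, running = [], 1
  for d in reversed(shape):
    strides.append(running)
    running *= d
  return tuple(reversed(strides))

shape = (4, 6, 128, 64)               # minor dim: 64 * 2 bytes = 128 bytes
strides = contiguous_strides(shape)   # (49152, 8192, 64, 1)
# Every stride but the last is a multiple of 16 bytes, so the copy is allowed.
assert all(s * 2 % 16 == 0 for s in strides[:-1])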
34 changes: 12 additions & 22 deletions jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -61,25 +61,14 @@ def attention(q, k, v, config: TuningConfig):
raise ValueError(f"{head_dim=} must be divisible by 64")
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")
# Squash batch and sequence dimensions.
# This is required because CUDA grid/TMA descriptors have a limited number of
# slice dimensions.
# TODO(apaszke): Implement slice squashing for TMAs.
q = jnp.reshape(q, (batch_size * q_seq_len, num_q_heads, head_dim))
k = jnp.reshape(k, (batch_size * kv_seq_len, num_kv_heads, head_dim))
v = jnp.reshape(v, (batch_size * kv_seq_len, num_kv_heads, head_dim))

max_concurrent_steps = min(
config.max_concurrent_steps, kv_seq_len // config.block_kv
)
block_q, block_kv = config.block_q, config.block_kv
num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")

def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
bidx = lax.div(lax.axis_index("bq"), num_q_tiles)
qidx = lax.rem(lax.axis_index("bq"), num_q_tiles)
batch = lax.axis_index("batch")
smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
wg_idx = lax.axis_index("wg")
qo_smem2, k_smem, v_smem = smem_buffers
@@ -93,11 +82,11 @@ def perform_schedule_barrier():
def _compute_wg():
plgpu.set_max_registers(232, action="increase")
qo_smem = qo_smem2.at[wg_idx]
q_seq_base = qidx * (2 * block_q) + wg_idx * block_q + bidx * q_seq_len
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
q_head = lax.axis_index("heads")

plgpu.copy_gmem_to_smem(
q_ref.at[pl.ds(q_seq_base, block_q), q_head],
q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
qo_smem,
q_barriers.at[wg_idx],
)
@@ -167,24 +156,22 @@ def _wait():
qo_smem[...] = acc.astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
qo_smem, out_ref.at[pl.ds(q_seq_base, block_q), q_head],
qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
)
plgpu.wait_smem_to_gmem(0)
@pl.when(wg_idx == 2)
def _memory_wg():
plgpu.set_max_registers(40, action="decrease")
kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
for i in range(max_concurrent_steps):
start = i * block_kv + bidx * kv_seq_len
s = (pl.ds(start, block_kv), kv_head)
s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i])

def kv_loop(kv_step, _):
tma_step = kv_step + max_concurrent_steps
tma_slot = lax.rem(kv_step, max_concurrent_steps)
start = tma_step * block_kv + bidx * kv_seq_len
s = (pl.ds(start, block_kv), kv_head)
s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head)
plgpu.barrier_wait(k_consumed_barrier)
plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot])
plgpu.barrier_wait(v_consumed_barrier)
@@ -199,10 +186,13 @@ def kv_epilogue(i, _):
def run(refs):
q_ref, k_ref, v_ref, out_ref = refs

num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
mesh = plgpu.GPUMesh(
grid=(batch_size * num_q_tiles, num_q_heads),
grid=(batch_size, num_q_tiles, num_q_heads),
num_threads=3,
axis_names=("bq", "heads", "wg"),
axis_names=("batch", "q_seq", "heads", "wg"),
approx_math=True,
)
@pl.core_map(mesh)
@@ -236,7 +226,7 @@ def _kernel_entry():
)

_, _, _, out = pl.run_state(run)((q, k, v, jnp.full_like(q, jnp.inf)))
return jnp.reshape(out, [batch_size, q_seq_len, num_q_heads, head_dim])
return out


@jax.jit
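
With the batch dimension no longer squashed into the sequence dimension, the kernel takes and returns 4D arrays directly. A hypothetical usage sketch (the TuningConfig field names are taken from the code above, but the concrete shapes and block sizes are illustrative only):

import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention_mgpu

batch, seq, heads, head_dim = 2, 4096, 16, 64
q = jnp.zeros((batch, seq, heads, head_dim), jnp.float16)
k = jnp.zeros((batch, seq, heads, head_dim), jnp.float16)
v = jnp.zeros((batch, seq, heads, head_dim), jnp.float16)
cfg = attention_mgpu.TuningConfig(
    block_q=64, block_kv=64, max_concurrent_steps=2)
out = attention_mgpu.attention(q, k, v, cfg)
print(out.shape)  # (2, 4096, 16, 64) -- no reshape back needed anymore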
2 changes: 1 addition & 1 deletion tests/mosaic/gpu_test.py
@@ -1240,7 +1240,7 @@ def run_kernel(shape):
x = np.arange(np.prod(shape)).reshape(shape)
_ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x)

with self.assertRaisesRegex(ValueError, "only support striding up to 5"):
with self.assertRaisesRegex(ValueError, "all GMEM strides except the last"):
run_kernel([1] * 6)

with self.assertRaisesRegex(
