diff --git a/README.md b/README.md
index 7cca4c3..dae250c 100644
--- a/README.md
+++ b/README.md
@@ -112,4 +112,4 @@
 seq_len for query = 1,
 - FlaxAttention (Without Pallas Flash Attention): **0.00650s**
 - Jax Pallas Decoding Attention (no score_mod): 0.00998s
-We can see that pure JAX implementation is actually the fastest, surpassing Palllas Flash Attention. The kernel also supports arbitrary query length and the inflection point is around 64, where the Palllas Flash Attention starts to outperform the pure JAX implementation when the query length is greater than 64.
\ No newline at end of file
+We can see that the pure JAX implementation is actually the fastest here, surpassing Pallas Flash Attention. The kernel also supports arbitrary query lengths; the inflection point is around a query length of 64, beyond which Pallas Flash Attention starts to outperform the pure JAX implementation. (With autograd, the inflection point is much worse, at around 1024.)
\ No newline at end of file
diff --git a/flaxattention/kernel/attention.py b/flaxattention/kernel/attention.py
index a050109..88b6d74 100644
--- a/flaxattention/kernel/attention.py
+++ b/flaxattention/kernel/attention.py
@@ -203,14 +203,14 @@ def mha(
     score_mod_grad: _score_mod_signature | None = None,
 ):
     del backward_pass_impl
-    batch_size, seq_len_q, num_heads, head_dim = q.shape
-    seq_len_kv = k.shape[1]
-    block_q = min(block_q, seq_len_q)
-    block_k = min(block_k, seq_len_kv)
+    batch_size, q_seq_len, num_heads, head_dim = q.shape
+    kv_seq_len = k.shape[1]
+    block_q = min(block_q, q_seq_len)
+    block_k = min(block_k, kv_seq_len)
     # Heuristics.
     grid_ = grid
     if grid_ is None:
-        grid_ = (pl.cdiv(seq_len_q, block_q), batch_size, num_heads)  # seq, batch, head
+        grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)  # seq, batch, head

     num_warps_ = num_warps
     if num_warps_ is None:
@@ -230,13 +230,13 @@ def mha(

     in_specs = [
         pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)),
-        pl.BlockSpec((None, seq_len_kv, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-        pl.BlockSpec((None, seq_len_kv, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
     ]
     in_specs.append(
         None  # type: ignore[arg-type]
         if segment_ids is None
-        else pl.BlockSpec((None, seq_len_kv), lambda _, j, k: (j, 0))
+        else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
     )
     out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
     return pl.pallas_call(
@@ -276,13 +276,14 @@ def _mha_forward(
     score_mod_grad: _score_mod_signature | None,
 ):
     del backward_pass_impl, score_mod_grad
-    batch_size, seq_len, num_heads, head_dim = q.shape
-    block_q = min(block_q, seq_len)
-    block_k = min(block_k, seq_len)
+    batch_size, q_seq_len, num_heads, head_dim = q.shape
+    kv_seq_len = k.shape[1]
+    block_q = min(block_q, q_seq_len)
+    block_k = min(block_k, kv_seq_len)
     # Heuristics.
     grid_ = grid
     if grid_ is None:
-        grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
+        grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)  # seq, batch, head

     num_warps_ = num_warps
     if num_warps_ is None:
@@ -302,19 +303,19 @@ def _mha_forward(
     out_shape = [
         jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
         jax.ShapeDtypeStruct(
-            shape=(batch_size, num_heads, seq_len),
+            shape=(batch_size, num_heads, q_seq_len),
             dtype=jnp.float32,  # lse
         ),
     ]
     in_specs = [
         pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)),
-        pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-        pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
     ]
     in_specs.append(
         None  # type: ignore[arg-type]
         if segment_ids is None
-        else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
+        else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
     )
     out, lse = pl.pallas_call(
         kernel,
@@ -394,7 +395,8 @@ def mha_backward_kernel(
     score_mod_grad: _score_mod_signature | None = None,
 ):
     del out_ref  # Not needed
-    seq_len = q_ref.shape[0]
+    q_seq_len = q_ref.shape[0]
+    kv_seq_len = k_ref.shape[0]

     # Scan #1: dK and dV
     # 1. Load a block of K and V of size (block_k1, head_dim) in SMEM.
@@ -470,7 +472,7 @@ def inner_loop_dkdv(start_q, carry):

     lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0
     dv, dk = lax.fori_loop(
-        lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
+        lower_bound, pl.cdiv(q_seq_len, block_q1), inner_loop_dkdv, (dv, dk)
     )
     dv_ref[...] = dv.astype(dv_ref.dtype)
     dk_ref[...] = dk.astype(dk_ref.dtype)
@@ -549,7 +551,7 @@ def inner_loop_dq(start_k, dq):
     if causal:
         upper_bound = lax.div((start_q + 1) * block_q2, block_k2)
     else:
-        upper_bound = pl.cdiv(seq_len, block_k2)
+        upper_bound = pl.cdiv(kv_seq_len, block_k2)
     dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
     dq_ref[...] = dq.astype(dq_ref.dtype)

@@ -584,31 +586,31 @@ def _mha_backward(
             segment_ids,
         )[1](do)
     elif backward_pass_impl == "triton":
-        batch_size, seq_len, num_heads, head_dim = q.shape
-        block_q = min(block_q, seq_len)
-        block_k = min(block_k, seq_len)
+        batch_size, q_seq_len, num_heads, head_dim = q.shape
+        kv_seq_len = k.shape[1]
+        block_q = min(block_q, q_seq_len)
+        block_k = min(block_k, kv_seq_len)
         delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
         out_shapes = [
             jax.ShapeDtypeStruct(q.shape, q.dtype),
             jax.ShapeDtypeStruct(k.shape, k.dtype),
             jax.ShapeDtypeStruct(v.shape, v.dtype),
         ]

-        in_specs = [
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
-            pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
+        in_specs = [
+            pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
+            pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
         ]
         if segment_ids is None:
            in_specs.insert(3, None)  # type: ignore[arg-type]
         else:
-            in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0)))
+            in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0)))

-        grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k))
+        grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_k))
         num_warps = 4 if head_dim <= 64 else 8
         dq, dk, dv = pl.pallas_call(
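
The change above decouples the query and key/value sequence lengths (`q_seq_len` vs. `kv_seq_len`) through the forward and backward passes, which is what lets the kernel run a single query token against a much longer KV cache, as in the decoding benchmark quoted from the README. Below is a minimal usage sketch of that case; the import path and the exact `mha` call (positional `q`, `k`, `v` plus a `segment_ids` argument) are assumptions inferred from this diff, not code taken from the repository.

```python
# Minimal sketch: single-token decoding query against a longer KV cache.
# Assumes `mha` is importable as below and accepts (q, k, v, segment_ids=...).
import jax
import jax.numpy as jnp

from flaxattention.kernel.attention import mha  # assumed import path

batch, num_heads, head_dim = 4, 8, 64
q_seq_len, kv_seq_len = 1, 2048  # q_seq_len != kv_seq_len is now supported

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, q_seq_len, num_heads, head_dim), dtype=jnp.float16)
k = jax.random.normal(kk, (batch, kv_seq_len, num_heads, head_dim), dtype=jnp.float16)
v = jax.random.normal(kv, (batch, kv_seq_len, num_heads, head_dim), dtype=jnp.float16)

# block_q is clamped to min(block_q, q_seq_len) inside mha, so the default
# block sizes remain valid even when the query has a single position.
out = mha(q, k, v, segment_ids=None)
print(out.shape)  # (4, 1, 8, 64)
```

Per the README numbers, at a query length of 1 the pure JAX path is still faster; the Pallas kernel is only expected to win once the query length grows past roughly 64.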