fix autograd for len_q != len_kv
zinccat committed Oct 24, 2024
1 parent a6860a4 commit eb229b9
Showing 2 changed files with 35 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -112,4 +112,4 @@ seq_len for query = 1,
 - FlaxAttention (Without Pallas Flash Attention): **0.00650s**
 - Jax Pallas Decoding Attention (no score_mod): 0.00998s
 
-We can see that the pure JAX implementation is actually the fastest, surpassing Pallas Flash Attention. The kernel also supports arbitrary query lengths; the inflection point is around a query length of 64, beyond which Pallas Flash Attention starts to outperform the pure JAX implementation.
+We can see that the pure JAX implementation is actually the fastest, surpassing Pallas Flash Attention. The kernel also supports arbitrary query lengths; the inflection point is around a query length of 64, beyond which Pallas Flash Attention starts to outperform the pure JAX implementation. (For autograd, the inflection point is around 1024, which is quite bad.)
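As a point of reference for the numbers above, a pure-JAX decoding attention (seq_len for query = 1) is just a softmax-weighted sum over the cached keys and values. The sketch below is a minimal stand-in for that baseline; the function name and shapes are assumptions, not the repository's code:

```python
# Minimal pure-JAX decoding-attention baseline (a sketch, not the repository's
# implementation). One query position attends over the whole key/value cache.
import jax
import jax.numpy as jnp

def decode_attention(q, k, v):
    # q: (batch, 1, num_heads, head_dim); k, v: (batch, kv_seq_len, num_heads, head_dim)
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)  # (batch, 1, num_heads, head_dim)
```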
66 changes: 34 additions & 32 deletions flaxattention/kernel/attention.py
@@ -203,14 +203,14 @@ def mha(
     score_mod_grad: _score_mod_signature | None = None,
 ):
     del backward_pass_impl
-    batch_size, seq_len_q, num_heads, head_dim = q.shape
-    seq_len_kv = k.shape[1]
-    block_q = min(block_q, seq_len_q)
-    block_k = min(block_k, seq_len_kv)
+    batch_size, q_seq_len, num_heads, head_dim = q.shape
+    kv_seq_len = k.shape[1]
+    block_q = min(block_q, q_seq_len)
+    block_k = min(block_k, kv_seq_len)
     # Heuristics.
     grid_ = grid
     if grid_ is None:
-        grid_ = (pl.cdiv(seq_len_q, block_q), batch_size, num_heads)  # seq, batch, head
+        grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)  # seq, batch, head

     num_warps_ = num_warps
     if num_warps_ is None:
@@ -230,13 +230,13 @@ def mha(

     in_specs = [
         pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)),
-        pl.BlockSpec((None, seq_len_kv, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-        pl.BlockSpec((None, seq_len_kv, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
     ]
     in_specs.append(
         None  # type: ignore[arg-type]
         if segment_ids is None
-        else pl.BlockSpec((None, seq_len_kv), lambda _, j, k: (j, 0))
+        else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
     )
     out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
     return pl.pallas_call(
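The forward change above keeps the query dimension tiled by block_q (one grid program per query block), while the K and V BlockSpecs now carry kv_seq_len, since every program reads the full key/value length. A rough pure-JAX emulation of that blocking, with hypothetical names and no online softmax, looks like this:

```python
# Rough pure-JAX emulation of the forward blocking (a hypothetical helper, not
# the Pallas kernel): grid dimension 0 tiles only the query length, while each
# program sees K and V over the full kv_seq_len.
import jax
import jax.numpy as jnp

def blocked_attention_reference(q, k, v, block_q=128):
    batch, q_seq_len, num_heads, head_dim = q.shape
    kv_seq_len = k.shape[1]  # may differ from q_seq_len
    scale = head_dim ** -0.5
    out_blocks = []
    for start in range(0, q_seq_len, block_q):  # one iteration per query block
        q_blk = q[:, start:start + block_q]     # (batch, <=block_q, heads, dim)
        s = jnp.einsum("bqhd,bkhd->bhqk", q_blk, k) * scale  # keys span kv_seq_len
        p = jax.nn.softmax(s, axis=-1)
        out_blocks.append(jnp.einsum("bhqk,bkhd->bqhd", p, v))
    return jnp.concatenate(out_blocks, axis=1)  # (batch, q_seq_len, heads, dim)
```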
@@ -276,13 +276,14 @@ def _mha_forward(
     score_mod_grad: _score_mod_signature | None,
 ):
     del backward_pass_impl, score_mod_grad
-    batch_size, seq_len, num_heads, head_dim = q.shape
-    block_q = min(block_q, seq_len)
-    block_k = min(block_k, seq_len)
+    batch_size, q_seq_len, num_heads, head_dim = q.shape
+    kv_seq_len = k.shape[1]
+    block_q = min(block_q, q_seq_len)
+    block_k = min(block_k, kv_seq_len)
     # Heuristics.
     grid_ = grid
     if grid_ is None:
-        grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
+        grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)  # seq, batch, head

     num_warps_ = num_warps
     if num_warps_ is None:
@@ -302,19 +303,19 @@
     out_shape = [
         jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
         jax.ShapeDtypeStruct(
-            shape=(batch_size, num_heads, seq_len),
+            shape=(batch_size, num_heads, q_seq_len),
             dtype=jnp.float32,  # lse
         ),
     ]
     in_specs = [
         pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)),
-        pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-        pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
     ]
     in_specs.append(
         None  # type: ignore[arg-type]
         if segment_ids is None
-        else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
+        else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
     )
     out, lse = pl.pallas_call(
         kernel,
@@ -394,7 +395,8 @@ def mha_backward_kernel(
     score_mod_grad: _score_mod_signature | None = None,
 ):
     del out_ref  # Not needed
-    seq_len = q_ref.shape[0]
+    q_seq_len = q_ref.shape[0]
+    kv_seq_len = k_ref.shape[0]

     # Scan #1: dK and dV
     # 1. Load a block of K and V of size (block_k1, head_dim) in SMEM.
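The two scans referenced in these comments are why separate lengths matter in the backward kernel: Scan #1 keeps a K/V block resident and iterates over query blocks, so its loop bound must come from q_seq_len, while Scan #2 keeps a Q block resident and iterates over key/value blocks, so its bound must come from kv_seq_len. A plain-JAX reference backward for a single (batch, head) slice makes those reductions explicit; this is a sketch, not the kernel:

```python
# Reference backward for one (batch, head) slice, showing which length each
# gradient reduces over (a plain-JAX sketch, not the Pallas kernel).
import jax
import jax.numpy as jnp

def attention_backward_reference(q, k, v, do):
    # q, do: (q_seq_len, head_dim); k, v: (kv_seq_len, head_dim)
    scale = q.shape[-1] ** -0.5
    s = q @ k.T * scale                                      # (q_seq_len, kv_seq_len)
    p = jax.nn.softmax(s, axis=-1)
    dv = p.T @ do                                            # sums over q_seq_len -> (kv_seq_len, head_dim)
    dp = do @ v.T                                            # (q_seq_len, kv_seq_len)
    ds = p * (dp - jnp.sum(dp * p, axis=-1, keepdims=True))  # softmax backward
    dk = ds.T @ q * scale                                    # sums over q_seq_len -> (kv_seq_len, head_dim)
    dq = ds @ k * scale                                      # sums over kv_seq_len -> (q_seq_len, head_dim)
    return dq, dk, dv
```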
Expand Down Expand Up @@ -470,7 +472,7 @@ def inner_loop_dkdv(start_q, carry):

lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0
dv, dk = lax.fori_loop(
lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
lower_bound, pl.cdiv(q_seq_len, block_q1), inner_loop_dkdv, (dv, dk)
)
dv_ref[...] = dv.astype(dv_ref.dtype)
dk_ref[...] = dk.astype(dk_ref.dtype)
@@ -549,7 +551,7 @@ def inner_loop_dq(start_k, dq):
     if causal:
         upper_bound = lax.div((start_q + 1) * block_q2, block_k2)
     else:
-        upper_bound = pl.cdiv(seq_len, block_k2)
+        upper_bound = pl.cdiv(kv_seq_len, block_k2)

     dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
     dq_ref[...] = dq.astype(dq_ref.dtype)
@@ -584,31 +586,31 @@ def _mha_backward(
             segment_ids,
         )[1](do)
     elif backward_pass_impl == "triton":
-        batch_size, seq_len, num_heads, head_dim = q.shape
-        block_q = min(block_q, seq_len)
-        block_k = min(block_k, seq_len)
+        batch_size, q_seq_len, num_heads, head_dim = q.shape
+        kv_seq_len = k.shape[1]
+        block_q = min(block_q, q_seq_len)
+        block_k = min(block_k, kv_seq_len)
         delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
         out_shapes = [
             jax.ShapeDtypeStruct(q.shape, q.dtype),
             jax.ShapeDtypeStruct(k.shape, k.dtype),
             jax.ShapeDtypeStruct(v.shape, v.dtype),
         ]

         in_specs = [
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
-            pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
-            pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
+            pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)),
+            pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
+            pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
         ]
         if segment_ids is None:
             in_specs.insert(3, None)  # type: ignore[arg-type]
         else:
-            in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0)))
+            in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0)))

-        grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k))
+        grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_k))
         num_warps = 4 if head_dim <= 64 else 8

         dq, dk, dv = pl.pallas_call(
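A quick way to see the property this fix has to preserve: with q_seq_len != kv_seq_len, autograd must return dq with the query length and dk/dv with the key/value length. The check below uses a self-contained reference attention as a stand-in for the Pallas kernel; all names and shapes here are illustrative:

```python
# Gradient-shape sanity check with mismatched lengths (reference attention as a
# stand-in for the Pallas kernel; names and shapes are illustrative).
import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    scale = q.shape[-1] ** -0.5
    s = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(s, axis=-1), v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 64, 4, 64))    # q_seq_len = 64
k = jax.random.normal(kk, (2, 256, 4, 64))   # kv_seq_len = 256
v = jax.random.normal(kv, (2, 256, 4, 64))

dq, dk, dv = jax.grad(lambda q, k, v: reference_attention(q, k, v).sum(),
                      argnums=(0, 1, 2))(q, k, v)
print(dq.shape, dk.shape, dv.shape)  # (2, 64, 4, 64) (2, 256, 4, 64) (2, 256, 4, 64)
```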