Flash attention #931

Open
wants to merge 9 commits into main
12 changes: 12 additions & 0 deletions alphafold/model/config.py
@@ -384,6 +384,12 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'use_remat': False,
'zero_init': True,
'eval_dropout': False,
'use_flash_attention': False,
'flash': {
'num_warps': 2,
'block_q': 64,
'block_k': 32
}
},
'heads': {
'distogram': {
@@ -618,6 +624,12 @@ def model_config(name: str) -> ml_collections.ConfigDict:
'use_remat': False,
'zero_init': True,
'eval_dropout': False,
'use_flash_attention': False,
'flash': {
'num_warps': 2,
'block_q': 64,
'block_k': 32
}
},
'heads': {
'distogram': {
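
For reference, a minimal sketch of how these new options could be toggled on a built config. The override path through model.global_config follows the existing layout returned by config.model_config, 'model_1' is one of the standard presets, and the values shown simply restate the defaults added in this diff; they are illustrative, not tuned settings.

from alphafold.model import config

cfg = config.model_config('model_1')
# Switch the Attention module over to the Pallas flash-attention kernel.
cfg.model.global_config.use_flash_attention = True
# Kernel tile sizes and warp count; sequences shorter than a block fall back
# to the reference kernel inside Attention.__call__.
cfg.model.global_config.flash.block_q = 64
cfg.model.global_config.flash.block_k = 32
cfg.model.global_config.flash.num_warps = 2
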
251 changes: 230 additions & 21 deletions alphafold/model/modules.py
@@ -30,6 +30,8 @@
import haiku as hk
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import jax.lax as lax


_SOFTMAX_MASK = -1e9
@@ -558,8 +560,8 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
m_data: A tensor of memories from which the keys and values are
projected, shape [batch_size, N_keys, m_channels].
mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
mask: A mask for the attention, shape [batch_size or 1, N_heads or 1, N_queries or 1, N_keys].
nonbatched_bias: Shared bias, shape [N_heads, N_queries, N_keys].

Returns:
A float32 tensor of shape [batch_size, N_queries, output_dim].
@@ -573,6 +575,7 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
key_dim = key_dim // num_head
value_dim = value_dim // num_head

# Parameter definitions (all weights are declared up front, before the kernel dispatch)
q_weights = hk.get_parameter(
'query_w', shape=(q_data.shape[-1], num_head, key_dim),
dtype=q_data.dtype,
@@ -586,16 +589,6 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
dtype=q_data.dtype,
init=glorot_uniform())

q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
if nonbatched_bias is not None:
logits += jnp.expand_dims(nonbatched_bias, axis=0)
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
weights = utils.stable_softmax(logits)
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)

if self.global_config.zero_init:
init = hk.initializers.Constant(0.0)
else:
@@ -613,13 +606,6 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
dtype=q_data.dtype,
init=hk.initializers.Constant(1.0))

gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
gating_weights) + gating_bias

gate_values = jax.nn.sigmoid(gate_values)

weighted_avg *= gate_values

o_weights = hk.get_parameter(
'output_w', shape=(num_head, value_dim, self.output_dim),
dtype=q_data.dtype,
@@ -628,12 +614,236 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
'output_b', shape=(self.output_dim,),
dtype=q_data.dtype,
init=hk.initializers.Constant(0.0))

# pre-kernel
if self.config.gating:
gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
gating_weights) + gating_bias
gate_values = jax.nn.sigmoid(gate_values)
else:
gate_values = None
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)

# kernel: use the Pallas flash-attention kernel only when it is enabled and both
# sequence lengths cover at least one block; otherwise fall back to the reference kernel.
q_len, k_len = q.shape[1], k.shape[1]
kernel = functools.partial(self.flash_kernel, **self.global_config.flash) if (
self.global_config.use_flash_attention and
q_len >= self.global_config.flash.block_q and
k_len >= self.global_config.flash.block_k) else self.reference_kernel
weighted_avg = kernel(q, k, v, mask=mask, nonbatched_bias=nonbatched_bias, gate_values=gate_values)

# post-kernel
output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias

return output


@staticmethod
def reference_kernel(q, k, v, mask, nonbatched_bias=None, gate_values=None):
# Standard attention: materialise the full [batch, heads, q, k] logits matrix.
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
if nonbatched_bias is not None:
logits += jnp.expand_dims(nonbatched_bias, axis=0)
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
weights = utils.stable_softmax(logits)
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
if gate_values is not None:
weighted_avg *= gate_values
return weighted_avg

@staticmethod
def flash_pallas_kernel(
# Input arrays
q_ref: jax.Array,
k_ref: jax.Array,
v_ref: jax.Array,
gate_values_ref: jax.Array | None,
mask_ref: jax.Array | None,
logit_bias_ref: jax.Array | None,
# Output arrays
o_ref: jax.Array,
L_ref: jax.Array | None,
# Options
block_q: int,
block_k: int,
):
# Convenience functions matching the notation of the FlashAttention-2 forward pass (Algorithm 1).
_log = jnp.log
_exp = jnp.exp
_maximum = jnp.maximum
_rowmax = lambda x: x.max(axis=-1)
_rowsum = lambda x: x.sum(axis=-1)
_diag = lambda x: x[:, None]
_dot = lambda a, b: pl.dot(a.astype(b.dtype), b) # match the dtype of the second operand; in dot(p, v) this keeps the matmul in v's (possibly bfloat16) dtype
_generate_1d_bounds_mask = lambda ref, start, block_size: lax.iota(jnp.int32, block_size)+start*block_size < ref.shape[0] if (ref.shape[0]%block_size!=0) else None
# When head_dim < 16 pl.dot is not supported so we pad up
_head_dim_mask = lambda ref: jax.lax.iota(jnp.int32, 16)<ref.shape[-1]
def _combine_masks(*masks: jax.Array | None):
# Combine 1-D per-axis masks into a single broadcasted mask; a None entry stands for an all-True mask along that axis.
dims = set(range(len(masks)))
masks = [jnp.expand_dims(mask, tuple(dims-set({i}))) for i, mask in enumerate(masks) if mask is not None]
if len(masks) == 0: return None  # all None case - don't use a mask
mask = functools.reduce(lambda x, y: x & y, masks)
return mask
def _load_mask_kwargs(masks: tuple[jax.Array | None, ...], other=0.):
mask = _combine_masks(*masks)
return {'mask':mask, 'other':other} if (mask is not None) else {}
def _padding_load(ref, slice, dim0_mask=None):
return pl.load(ref, (slice, pl.dslice(None)), **_load_mask_kwargs((dim0_mask, None))) if (ref.shape[-1]>=16) else \
pl.load(ref, (slice, pl.dslice(0,16)), **_load_mask_kwargs((dim0_mask, _head_dim_mask(ref))))
def _padding_store(ref, slice, val, dim0_mask=None):
pl.store(ref, (slice, pl.dslice(None)), val, mask=_combine_masks(dim0_mask, None)) if (ref.shape[-1]>=16) else \
pl.store(ref, (slice, pl.dslice(0,16)), val, mask=_combine_masks(dim0_mask, _head_dim_mask(ref)))

# m (the running row-max of the logits) and l (the running softmax denominator)
# are updated during the kv loop.
# o is the buffer where we accumulate the output in SRAM.
m_init = jnp.zeros(block_q, dtype=jnp.float32) - 1e9
l_init = jnp.zeros(block_q, dtype=jnp.float32)
o_init = jnp.zeros((block_q, max(v_ref.shape[-1], 16)), dtype=jnp.float32)

# Grid loops over q
start_q = pl.program_id(0)
q_slice = pl.dslice(start_q * block_q, block_q)
q_bounds_mask = _generate_1d_bounds_mask(q_ref, start_q, block_q)
q = _padding_load(q_ref, q_slice, q_bounds_mask)

# Here we only loop over blocks of kv to process the entire seq_len; the loop over
# blocks of q is carried out by the grid.
def body(start_k, carry):
o_prev, m_prev, l_prev = carry
k_slice = pl.dslice(start_k * block_k, block_k)
kv_bounds_mask = _generate_1d_bounds_mask(k_ref, start_k, block_k)

k = _padding_load(k_ref, k_slice, kv_bounds_mask)
v = _padding_load(v_ref, k_slice, kv_bounds_mask)
qk = pl.dot(q, k.T) # (block_q, block_k)

if (logit_bias_ref is not None):
logit_bias = pl.load(logit_bias_ref, (q_slice, k_slice),
**_load_mask_kwargs((q_bounds_mask, kv_bounds_mask)))
qk += logit_bias

if (mask_ref is not None):
kv_only_mask = (mask_ref.shape[0]==1)
mask = pl.load(mask_ref, (pl.dslice(0,1) if kv_only_mask else q_slice, k_slice),
**_load_mask_kwargs((None if kv_only_mask else q_bounds_mask, kv_bounds_mask)))
qk = jnp.where(mask, qk, _SOFTMAX_MASK)

# boundary checks
bounds_mask = _combine_masks(q_bounds_mask, kv_bounds_mask)
if (bounds_mask is not None):
qk = jnp.where(bounds_mask, qk, _SOFTMAX_MASK)

# Mapping of the indexing to the FlashAttention-2 paper's notation:
# x_{i}^{j-1} = x_prev
# x_{i}^{j} = x
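# For reference, the online-softmax recurrence implemented below (FlashAttention-2, Algorithm 1):
#   m_j = max(m_{j-1}, rowmax(S_j))
#   P_j = exp(S_j - m_j)
#   l_j = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(P_j)
#   O_j = diag(exp(m_{j-1} - m_j)) * O_{j-1} + P_j @ V_j
# After the loop, the output is rescaled by diag(1/l) and the logsumexp residual is L = m + log(l).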
s = qk
m = _maximum(m_prev, _rowmax(s))
p = _exp(s-_diag(m)) # the _diag broadcast is not written explicitly in the paper, but is needed here
l = _exp(m_prev-m)*l_prev + _rowsum(p)
o = _diag(_exp(m_prev-m)) * o_prev + _dot(p,v)
return o, m, l

kv_len = k_ref.shape[0]
n_blocks = pl.cdiv(kv_len, block_k)
o, m, l = lax.fori_loop(0, n_blocks, body, (o_init, m_init, l_init))

if (gate_values_ref is not None):
gate_values = _padding_load(gate_values_ref, q_slice, q_bounds_mask)
o *= gate_values

o *= _diag(1./l)

if (L_ref is not None):
L = m + _log(l)
pl.store(L_ref, (q_slice,), L, mask=q_bounds_mask)
# Write output to dram.
o = o.astype(o_ref.dtype)
_padding_store(o_ref, q_slice, o, q_bounds_mask)

@classmethod
def flash_kernel(
cls,
q,
k,
v,
gate_values: jnp.ndarray | None = None,
mask: jnp.ndarray | None = None,
nonbatched_bias: jnp.ndarray | None = None,
return_residual: bool = False,
block_q: int = 32,
block_k: int = 32,
num_warps: int | None = 2,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
):
if (mask is not None):
mask = mask.astype(jnp.bool_)
batch_size, q_len, num_heads, qk_head_dim = q.shape
_, kv_len, _, _ = k.shape
_, _, _, v_head_dim = v.shape

block_q = min(block_q, q_len)
block_k = min(block_k, kv_len)

if grid is None:
grid = (pl.cdiv(q_len, block_q), batch_size, num_heads)

kernel = functools.partial(cls.flash_pallas_kernel,
block_q=block_q,
block_k=block_k)
out_shape = (
jax.ShapeDtypeStruct(shape=(batch_size, q_len, num_heads, v_head_dim), dtype=q.dtype), # out
jax.ShapeDtypeStruct(shape=(batch_size, num_heads, q_len), dtype=jnp.float32) if return_residual else None, # L
)
in_specs = (
# j,k in grid parallelise over batch and head
pl.BlockSpec(
lambda _, j, k: (j, 0, k, 0), (None, q_len, None, qk_head_dim) # bqh(c_qk)
),
pl.BlockSpec(
lambda _, j, k: (j, 0, k, 0), (None, kv_len, None, qk_head_dim) # bkh(c_qk)
),
pl.BlockSpec(
lambda _, j, k: (j, 0, k, 0), (None, kv_len, None, v_head_dim) # bkh(c_v)
),
)
in_specs+= (None,) if (gate_values is None) else (
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, q_len, None, v_head_dim)), # bqh(c_v)
)
# mask might be (b or 1, h or 1, q or 1, k) in shape; we handle the b and/or h broadcast here
# and the q broadcast in the kernel.
in_specs+= (None,) if (mask is None) else (
pl.BlockSpec(lambda _, j, k: (
j if mask.shape[0]!=1 else 0,
k if mask.shape[1]!=1 else 0,
0, 0), (None, None,)+mask.shape[2:4]), # bhqk
)
in_specs+= (None,) if (nonbatched_bias is None) else (
pl.BlockSpec(lambda _, j, k: (k, 0, 0), (None, q_len, kv_len)), # hqk
)

out, L = pl.pallas_call(
kernel,
grid=grid,
in_specs=in_specs,
out_specs=(
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, q_len, None, v_head_dim)), # bqh(c_v)
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, q_len)) if return_residual else None, # bhq
),
compiler_params=dict(
triton=dict(num_warps=num_warps, num_stages=num_stages)
),
out_shape=out_shape,
debug=debug,
interpret=interpret,
name="flash_attention",
)(q, k, v, gate_values, mask, nonbatched_bias)
return out
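
As a rough sanity check (not part of this diff), the two kernels can be compared directly, since reference_kernel and flash_kernel are static/class methods that touch no Haiku state; interpret=True runs the Pallas kernel in interpreter mode without a GPU. The shapes, tolerance and block sizes below are illustrative only.

import jax
import jax.numpy as jnp
from alphafold.model.modules import Attention

rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
b, q_len, k_len, h, d = 2, 128, 96, 4, 32
q = jax.random.normal(rng_q, (b, q_len, h, d))
k = jax.random.normal(rng_k, (b, k_len, h, d))
v = jax.random.normal(rng_v, (b, k_len, h, d))
mask = jnp.ones((b, 1, 1, k_len), dtype=jnp.bool_)  # broadcastable [b, h, q, k] mask

ref = Attention.reference_kernel(q, k, v, mask=mask)
out = Attention.flash_kernel(q, k, v, mask=mask, block_q=64, block_k=32, interpret=True)
assert jnp.allclose(ref, out, atol=1e-3)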

class GlobalAttention(hk.Module):
"""Global attention.

@@ -787,7 +997,6 @@ def __call__(self,
dtype=msa_act.dtype,
init=hk.initializers.RandomNormal(stddev=init_factor))
nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)

attn_mod = Attention(
c, self.global_config, msa_act.shape[-1])
msa_act = mapping.inference_subbatch(