diff --git a/alphafold/model/config.py b/alphafold/model/config.py
index 447c3e34..ab8875a6 100644
--- a/alphafold/model/config.py
+++ b/alphafold/model/config.py
@@ -384,6 +384,12 @@ def model_config(name: str) -> ml_collections.ConfigDict:
       'use_remat': False,
       'zero_init': True,
       'eval_dropout': False,
+      'use_flash_attention': False,
+      'flash': {
+          'num_warps': 2,
+          'block_q': 64,
+          'block_k': 32
+      }
     },
     'heads': {
       'distogram': {
@@ -618,6 +624,12 @@ def model_config(name: str) -> ml_collections.ConfigDict:
       'use_remat': False,
       'zero_init': True,
       'eval_dropout': False,
+      'use_flash_attention': False,
+      'flash': {
+          'num_warps': 2,
+          'block_q': 64,
+          'block_k': 32
+      }
     },
     'heads': {
       'distogram': {
diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py
index 554c078c..aaf8fc6e 100644
--- a/alphafold/model/modules.py
+++ b/alphafold/model/modules.py
@@ -30,6 +30,8 @@
 import haiku as hk
 import jax
 import jax.numpy as jnp
+from jax.experimental import pallas as pl
+import jax.lax as lax
 
 _SOFTMAX_MASK = -1e9
 
@@ -558,8 +560,8 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
       q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
       m_data: A tensor of memories from which the keys and values are
         projected, shape [batch_size, N_keys, m_channels].
-      mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
-      nonbatched_bias: Shared bias, shape [N_queries, N_keys].
+      mask: A mask for the attention, shape [batch_size or 1, N_heads or 1, N_queries or 1, N_keys].
+      nonbatched_bias: Shared bias, shape [N_heads, N_queries, N_keys].
 
     Returns:
       A float32 tensor of shape [batch_size, N_queries, output_dim].
@@ -573,6 +575,7 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
     key_dim = key_dim // num_head
     value_dim = value_dim // num_head
 
+    # weight loading
     q_weights = hk.get_parameter(
         'query_w', shape=(q_data.shape[-1], num_head, key_dim),
         dtype=q_data.dtype,
         init=glorot_uniform())
@@ -586,16 +589,6 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
         dtype=q_data.dtype,
         init=glorot_uniform())
 
-    q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
-    k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
-    v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
-    logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
-    if nonbatched_bias is not None:
-      logits += jnp.expand_dims(nonbatched_bias, axis=0)
-    logits = jnp.where(mask, logits, _SOFTMAX_MASK)
-    weights = utils.stable_softmax(logits)
-    weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
-
     if self.global_config.zero_init:
       init = hk.initializers.Constant(0.0)
     else:
@@ -613,13 +606,6 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
           dtype=q_data.dtype,
           init=hk.initializers.Constant(1.0))
 
-      gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
-                               gating_weights) + gating_bias
-
-      gate_values = jax.nn.sigmoid(gate_values)
-
-      weighted_avg *= gate_values
-
     o_weights = hk.get_parameter(
         'output_w', shape=(num_head, value_dim, self.output_dim),
         dtype=q_data.dtype,
@@ -628,12 +614,236 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
         'output_b', shape=(self.output_dim,),
         dtype=q_data.dtype,
         init=hk.initializers.Constant(0.0))
 
+    # pre-kernel
+    if self.config.gating:
+      gate_values = jnp.einsum('bqc, chv->bqhv', q_data,
+                               gating_weights) + gating_bias
+      gate_values = jax.nn.sigmoid(gate_values)
+    else:
+      gate_values = None
+    q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
+    k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
+    v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
+
+    # kernel
+    q_len, k_len = q.shape[1], k.shape[1]
+    kernel = functools.partial(self.flash_kernel, **self.global_config.flash) if (
+        self.global_config.use_flash_attention and
+        q_len >= self.global_config.flash.block_q and
+        k_len >= self.global_config.flash.block_k) else self.reference_kernel
+    weighted_avg = kernel(q, k, v, mask=mask, nonbatched_bias=nonbatched_bias,
+                          gate_values=gate_values)
+
+    # post-kernel
     output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias
 
     return output
-
+
+  @staticmethod
+  def reference_kernel(q, k, v, mask, nonbatched_bias=None, gate_values=None):
+    # kernel
+    logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
+    if nonbatched_bias is not None:
+      logits += jnp.expand_dims(nonbatched_bias, axis=0)
+    logits = jnp.where(mask, logits, _SOFTMAX_MASK)
+    weights = utils.stable_softmax(logits)
+    weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
+    if gate_values is not None:
+      weighted_avg *= gate_values
+    return weighted_avg
+
+  @staticmethod
+  def flash_pallas_kernel(
+      # Input arrays
+      q_ref: jax.Array,
+      k_ref: jax.Array,
+      v_ref: jax.Array,
+      gate_values_ref: jax.Array | None,
+      mask_ref: jax.Array | None,
+      logit_bias_ref: jax.Array | None,
+      # Output arrays
+      o_ref: jax.Array,
+      L_ref: jax.Array | None,
+      # Options
+      block_q: int,
+      block_k: int,
+  ):
+    # Convenience functions to match the syntax of the FlashAttention-2
+    # forward pass (Algorithm 1).
+    _log = jnp.log
+    _exp = jnp.exp
+    _maximum = jnp.maximum
+    _rowmax = lambda x: x.max(axis=-1)
+    _rowsum = lambda x: x.sum(axis=-1)
+    _diag = lambda x: x[:, None]
+    # _dot matches the dtype of the second operand; in dot(p, v) this pushes
+    # the computation to bfloat16.
+    _dot = lambda a, b: pl.dot(a.astype(b.dtype), b)
+    _generate_1d_bounds_mask = lambda ref, start, block_size: (
+        lax.iota(jnp.int32, block_size) + start * block_size < ref.shape[0]
+        if (ref.shape[0] % block_size != 0) else None)
+    # When head_dim < 16 pl.dot is not supported, so we pad up to 16 and mask.
+    _head_dim_mask = lambda ref: jax.lax.iota(jnp.int32, 16) < ref.shape[-1]
+    # Combine optional row/column bounds masks into a single 2D mask (or None).
+    _combine_masks = lambda dim0_mask, dim1_mask: (
+        None if (dim0_mask is None and dim1_mask is None) else
+        (dim0_mask[:, None] if dim0_mask is not None else True) &
+        (dim1_mask[None, :] if dim1_mask is not None else True))
+    # kwargs for pl.load: only pass a mask (with zero fill) when one is needed.
+    _load_mask_kwargs = lambda masks: (
+        {} if _combine_masks(*masks) is None
+        else dict(mask=_combine_masks(*masks), other=0.))
+    # Loads/stores that pad the head dimension up to 16 when necessary.
+    def _padding_load(ref, slice, dim0_mask=None):
+      return pl.load(ref, (slice, pl.dslice(None)), **_load_mask_kwargs((dim0_mask, None))) if (ref.shape[-1] >= 16) else \
+          pl.load(ref, (slice, pl.dslice(0, 16)), **_load_mask_kwargs((dim0_mask, _head_dim_mask(ref))))
+    def _padding_store(ref, slice, val, dim0_mask=None):
+      pl.store(ref, (slice, pl.dslice(None)), val, mask=_combine_masks(dim0_mask, None)) if (ref.shape[-1] >= 16) else \
+          pl.store(ref, (slice, pl.dslice(0, 16)), val, mask=_combine_masks(dim0_mask, _head_dim_mask(ref)))
+
+    # m and l are updated during the kv loop.
+    # o is the buffer where we accumulate the output on sram.
+    m_init = jnp.zeros(block_q, dtype=jnp.float32) - 1e9
+    l_init = jnp.zeros(block_q, dtype=jnp.float32)
+    o_init = jnp.zeros((block_q, max(v_ref.shape[-1], 16)), dtype=jnp.float32)
+
+    # Grid loops over q.
+    start_q = pl.program_id(0)
+    q_slice = pl.dslice(start_q * block_q, block_q)
+    q_bounds_mask = _generate_1d_bounds_mask(q_ref, start_q, block_q)
+    q = _padding_load(q_ref, q_slice, q_bounds_mask)
+
+    # Here we only loop over blocks of kv to process the entire seq_len; the
+    # loop over blocks of q is carried out by the grid.
+    def body(start_k, carry):
+      o_prev, m_prev, l_prev = carry
+      k_slice = pl.dslice(start_k * block_k, block_k)
+      kv_bounds_mask = _generate_1d_bounds_mask(k_ref, start_k, block_k)
+
+      k = _padding_load(k_ref, k_slice, kv_bounds_mask)
+      v = _padding_load(v_ref, k_slice, kv_bounds_mask)
+      qk = pl.dot(q, k.T)  # (block_q, block_k)
+
+      if (logit_bias_ref is not None):
+        logit_bias = pl.load(logit_bias_ref, (q_slice, k_slice),
+                             **_load_mask_kwargs((q_bounds_mask, kv_bounds_mask)))
+        qk += logit_bias
+
+      if (mask_ref is not None):
+        kv_only_mask = (mask_ref.shape[0] == 1)
+        mask = pl.load(mask_ref, (pl.dslice(0, 1) if kv_only_mask else q_slice, k_slice),
+                       **_load_mask_kwargs((None if kv_only_mask else q_bounds_mask, kv_bounds_mask)))
+        qk = jnp.where(mask, qk, _SOFTMAX_MASK)
+
+      # boundary checks
+      bounds_mask = _combine_masks(q_bounds_mask, kv_bounds_mask)
+      if (bounds_mask is not None):
+        qk = jnp.where(bounds_mask, qk, _SOFTMAX_MASK)
+
+      # Mapping of indexing to the FlashAttention-2 paper's notation:
+      #   x_{i}^{j-1} = x_prev
+      #   x_{i}^{j}   = x
+      s = qk
+      m = _maximum(m_prev, _rowmax(s))
+      p = _exp(s - _diag(m))  # _diag not present in the paper, but is needed
+      l = _exp(m_prev - m) * l_prev + _rowsum(p)
+      o = _diag(_exp(m_prev - m)) * o_prev + _dot(p, v)
+      return o, m, l
+
+    kv_len = k_ref.shape[0]
+    n_blocks = pl.cdiv(kv_len, block_k)
+    o, m, l = lax.fori_loop(0, n_blocks, body, (o_init, m_init, l_init))
+
+    if (gate_values_ref is not None):
+      gate_values = _padding_load(gate_values_ref, q_slice, q_bounds_mask)
+      o *= gate_values
+
+    o *= _diag(1. / l)
+
+    if (L_ref is not None):
+      L = m + _log(l)
+      pl.store(L_ref, (q_slice,), L, mask=q_bounds_mask)
+    # Write output to dram.
+    o = o.astype(o_ref.dtype)
+    _padding_store(o_ref, q_slice, o, q_bounds_mask)
+
+  @classmethod
+  def flash_kernel(
+      cls,
+      q,
+      k,
+      v,
+      gate_values: jnp.ndarray | None = None,
+      mask: jnp.ndarray | None = None,
+      nonbatched_bias: jnp.ndarray | None = None,
+      return_residual: bool = False,
+      block_q: int = 32,
+      block_k: int = 32,
+      num_warps: int | None = 2,
+      num_stages: int = 2,
+      grid: tuple[int, ...] | None = None,
+      interpret: bool = False,
+      debug: bool = False,
+  ):
+    if (mask is not None):
+      mask = mask.astype(jnp.bool_)
+    batch_size, q_len, num_heads, qk_head_dim = q.shape
+    _, kv_len, _, _ = k.shape
+    _, _, _, v_head_dim = v.shape
+
+    block_q = min(block_q, q_len)
+    block_k = min(block_k, kv_len)
+
+    if grid is None:
+      grid = (pl.cdiv(q_len, block_q), batch_size, num_heads)
+
+    kernel = functools.partial(cls.flash_pallas_kernel,
+                               block_q=block_q,
+                               block_k=block_k)
+    out_shape = (
+        jax.ShapeDtypeStruct(shape=(batch_size, q_len, num_heads, v_head_dim),
+                             dtype=q.dtype),  # out
+        jax.ShapeDtypeStruct(shape=(batch_size, num_heads, q_len),
+                             dtype=jnp.float32) if return_residual else None,  # L
+    )
+    in_specs = (
+        # j, k in the grid parallelise over batch and head.
+        pl.BlockSpec(
+            lambda _, j, k: (j, 0, k, 0), (None, q_len, None, qk_head_dim)  # bqh(c_qk)
+        ),
+        pl.BlockSpec(
+            lambda _, j, k: (j, 0, k, 0), (None, kv_len, None, qk_head_dim)  # bkh(c_qk)
+        ),
+        pl.BlockSpec(
+            lambda _, j, k: (j, 0, k, 0), (None, kv_len, None, v_head_dim)  # bkh(c_v)
+        ),
+    )
+    in_specs += (None,) if (gate_values is None) else (
+        pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, q_len, None, v_head_dim)),  # bqh(c_v)
+    )
+    # mask might be (b or 1, h or 1, q or 1, k) in shape, we deal with b and/or h broadcast here.
+    # we deal with q broadcast in the kernel
+    in_specs += (None,) if (mask is None) else (
+        pl.BlockSpec(lambda _, j, k: (
+            j if mask.shape[0] != 1 else 0,
+            k if mask.shape[1] != 1 else 0,
+            0, 0), (None, None,) + mask.shape[2:4]),  # bhqk
+    )
+    in_specs += (None,) if (nonbatched_bias is None) else (
+        pl.BlockSpec(lambda _, j, k: (k, 0, 0), (None, q_len, kv_len)),  # hqk
+    )
+
+    out, L = pl.pallas_call(
+        kernel,
+        grid=grid,
+        in_specs=in_specs,
+        out_specs=(
+            pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, q_len, None, v_head_dim)),  # bqh(c_v)
+            pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, q_len)) if return_residual else None,  # bhq
+        ),
+        compiler_params=dict(
+            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        ),
+        out_shape=out_shape,
+        debug=debug,
+        interpret=interpret,
+        name="flash_attention",
+    )(q, k, v, gate_values, mask, nonbatched_bias)
+    return out
+
 
 class GlobalAttention(hk.Module):
   """Global attention.
@@ -787,7 +997,6 @@ def __call__(self,
         dtype=msa_act.dtype,
         init=hk.initializers.RandomNormal(stddev=init_factor))
     nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights)
-
     attn_mod = Attention(
         c, self.global_config, msa_act.shape[-1])
     msa_act = mapping.inference_subbatch(
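
Usage note (not part of the diff): the sketch below shows how the new config flag might be flipped on and how flash_kernel can be sanity-checked against reference_kernel on random inputs. The tensor shapes, the interpret fallback for non-GPU backends, and the tolerance are assumptions chosen for illustration, not values taken from this PR.

import jax
import jax.numpy as jnp

from alphafold.model import config, modules

# Enable the new flag added to the global config above.
cfg = config.model_config('model_1')
cfg.model.global_config.use_flash_attention = True

# Random q/k/v in the [batch, length, heads, channels] layout used by Attention.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
b, q_len, kv_len, h, d = 1, 128, 128, 4, 32
q = jax.random.normal(k1, (b, q_len, h, d), dtype=jnp.float32)
k = jax.random.normal(k2, (b, kv_len, h, d), dtype=jnp.float32)
v = jax.random.normal(k3, (b, kv_len, h, d), dtype=jnp.float32)
mask = jnp.ones((b, 1, 1, kv_len), dtype=jnp.bool_)

# Both kernels take the same pre-projected inputs, so they should agree closely.
ref = modules.Attention.reference_kernel(q, k, v, mask=mask)
out = modules.Attention.flash_kernel(
    q, k, v, mask=mask, block_q=64, block_k=32,
    interpret=(jax.default_backend() != 'gpu'))  # Pallas interpret mode off-GPU
assert jnp.allclose(ref, out, atol=2e-2)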