From f86ad2973cd2d805086e0a8187314b2ca322a44b Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Fri, 19 Apr 2024 21:54:05 +0100 Subject: [PATCH 1/9] feat: flash attention --- alphafold/model/config.py | 6 + alphafold/model/modules.py | 236 +++++++++++++++++++++++++++++++++---- 2 files changed, 221 insertions(+), 21 deletions(-) diff --git a/alphafold/model/config.py b/alphafold/model/config.py index 447c3e34b..53826c785 100644 --- a/alphafold/model/config.py +++ b/alphafold/model/config.py @@ -384,6 +384,12 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'use_remat': False, 'zero_init': True, 'eval_dropout': False, + 'use_flash_attention': False, + 'flash': { + 'num_warps': 2, + 'block_q': 64, + 'block_k': 32 + } }, 'heads': { 'distogram': { diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index 554c078c0..1659c5bfc 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -30,6 +30,8 @@ import haiku as hk import jax import jax.numpy as jnp +from jax.experimental import pallas as pl +import jax.lax as lax _SOFTMAX_MASK = -1e9 @@ -558,8 +560,8 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): q_data: A tensor of queries, shape [batch_size, N_queries, q_channels]. m_data: A tensor of memories from which the keys and values are projected, shape [batch_size, N_keys, m_channels]. - mask: A mask for the attention, shape [batch_size, N_queries, N_keys]. - nonbatched_bias: Shared bias, shape [N_queries, N_keys]. + mask: A mask for the attention, shape [batch_size, N_heads or 1, N_queries or 1, N_keys]. + nonbatched_bias: Shared bias, shape [N_heads, N_queries, N_keys]. Returns: A float32 tensor of shape [batch_size, N_queries, output_dim]. @@ -573,6 +575,7 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): key_dim = key_dim // num_head value_dim = value_dim // num_head + # weight loading q_weights = hk.get_parameter( 'query_w', shape=(q_data.shape[-1], num_head, key_dim), dtype=q_data.dtype, @@ -586,16 +589,6 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): dtype=q_data.dtype, init=glorot_uniform()) - q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5) - k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights) - v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights) - logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) - if nonbatched_bias is not None: - logits += jnp.expand_dims(nonbatched_bias, axis=0) - logits = jnp.where(mask, logits, _SOFTMAX_MASK) - weights = utils.stable_softmax(logits) - weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v) - if self.global_config.zero_init: init = hk.initializers.Constant(0.0) else: @@ -613,13 +606,6 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): dtype=q_data.dtype, init=hk.initializers.Constant(1.0)) - gate_values = jnp.einsum('bqc, chv->bqhv', q_data, - gating_weights) + gating_bias - - gate_values = jax.nn.sigmoid(gate_values) - - weighted_avg *= gate_values - o_weights = hk.get_parameter( 'output_w', shape=(num_head, value_dim, self.output_dim), dtype=q_data.dtype, @@ -628,12 +614,221 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): 'output_b', shape=(self.output_dim,), dtype=q_data.dtype, init=hk.initializers.Constant(0.0)) + + # pre-kernel + if self.config.gating: + gate_values = jnp.einsum('bqc, chv->bqhv', q_data, + gating_weights) + gating_bias + gate_values = jax.nn.sigmoid(gate_values) + else: + gate_values = None + q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5) + k = 
jnp.einsum('bka,ahc->bkhc', m_data, k_weights) + v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights) + # kernel + q_len, k_len = q.shape[1], k.shape[1] + kernel = functools.partial(self.flash_kernel, **self.global_config.flash) if ( + ('use_flash_attention' in self.global_config) and + self.global_config.use_flash_attention and + q_len >= self.global_config.flash.block_q and + k_len >= self.global_config.flash.block_k) else self.reference_kernel + weighted_avg = kernel(q, k, v, mask=mask, nonbatched_bias=nonbatched_bias, gate_values=gate_values) + + # post-kernel output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_bias return output - + @staticmethod + def reference_kernel(q, k, v, mask, nonbatched_bias=None, gate_values=None): + # kernel + logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + if nonbatched_bias is not None: + logits += jnp.expand_dims(nonbatched_bias, axis=0) + logits = jnp.where(mask, logits, _SOFTMAX_MASK) + weights = utils.stable_softmax(logits) + weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v) + if gate_values is not None: + weighted_avg *= gate_values + return weighted_avg + + @staticmethod + def flash_pallas_kernel( + # Input arrays + q_ref: jax.Array, + k_ref: jax.Array, + v_ref: jax.Array, + gate_values_ref: jax.Array | None, + mask_ref: jax.Array | None, + logit_bias_ref: jax.Array | None, + # Output arrays + o_ref: jax.Array, + L_ref: jax.Array | None, + # Options + block_q: int, + block_k: int, + ): + # convenience functions to match syntax of FlashAttention2 forward pass (Algorithm 1) + _log = jnp.log + _exp = jnp.exp + _maximum = jnp.maximum + _rowmax = lambda x: x.max(axis=-1) + _rowsum = lambda x: x.sum(axis=-1) + _diag = lambda x: x[:, None] + _dot = pl.dot + _generate_1d_bounds_mask = lambda ref, start, block_size: lax.iota(jnp.int32, block_size)+start*block_size < ref.shape[0] + # When head_dim < 16 pl.dot is not supported so we pad up + _head_dim_mask = lambda ref: jax.lax.iota(jnp.int32, 16)=16) else \ + pl.load(ref, (slice, pl.dslice(0,16)), mask=_head_dim_mask(ref)[None,:], other=0.) + def _padding_store(ref, slice, val, dim0_mask): + pl.store(ref, (slice, pl.dslice(None)), val, mask=dim0_mask[:,None]) if (ref.shape[-1]>=16) else \ + pl.store(ref, (slice, pl.dslice(0,16)), val, mask=(dim0_mask[:,None] & _head_dim_mask(ref)[None,:])) + + # m and l are updated during the kv loop. + # o is the buffer where we accumulate the output on sram. + m_init = jnp.zeros(block_q, dtype=jnp.float32) - 1e9 + l_init = jnp.zeros(block_q, dtype=jnp.float32) + o_init = jnp.zeros((block_q, max(v_ref.shape[-1], 16)), dtype=jnp.float32) + + # Grid loops over q + start_q = pl.program_id(0) + q_slice = pl.dslice(start_q * block_q, block_q) + q = _padding_load(q_ref, q_slice) + q_bounds_mask = _generate_1d_bounds_mask(q_ref, start_q, block_q) + + # Here we only loop over blocks of kv to process entire seq_len, the loop over + # blocks of q is carried out by the grid. 
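+    # Note on the loop body below (FlashAttention2, Algorithm 1): each key/value
+    # block updates the running row-max m, the running softmax denominator l and
+    # the output accumulator o; the previous partial results are rescaled by
+    # exp(m_prev - m) whenever a new block raises the row maximum, which keeps the
+    # streamed softmax numerically stable without materialising the full logits.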
+ def body(start_k, carry): + o_prev, m_prev, l_prev = carry + k_slice = pl.dslice(start_k * block_k, block_k) + + k = _padding_load(k_ref, k_slice) + v = _padding_load(v_ref, k_slice) + qk = pl.dot(q, k.T) # (block_q, block_k) + + if (logit_bias_ref is not None): + logit_bias = pl.load(logit_bias_ref, (q_slice, k_slice)) + qk += logit_bias + + if (mask_ref is not None): + kv_only_mask = (mask_ref.shape[0]==1) + mask = pl.load(mask_ref, ( + pl.dslice(0,1) if kv_only_mask else q_slice, + k_slice)) + qk = jnp.where(mask, qk, _SOFTMAX_MASK) + + # boundary checks + kv_bounds_mask = _generate_1d_bounds_mask(k_ref, start_k, block_k) + bounds_mask = (q_bounds_mask[:,None] & kv_bounds_mask[None,:]) + qk = jnp.where(bounds_mask, qk, _SOFTMAX_MASK) + + # Mapping of indexing to FlashAttention2 papers notation: + # x_{i}^{j-1} = x_prev + # x_{i}^{j} = x + s = qk + rowmax_s = _rowmax(s) + m = _maximum(m_prev, rowmax_s) + p = _exp(s-_diag(m)) # _diag not present in paper, but is needed + l = _exp(m_prev-m)*l_prev + _rowsum(p) + o = _diag(_exp(m_prev-m)) * o_prev + _dot(p,v) + return o, m, l + + kv_len = k_ref.shape[0] + n_blocks = pl.cdiv(kv_len, block_k) + o, m, l = lax.fori_loop(0, n_blocks, body, (o_init, m_init, l_init)) + + if (gate_values_ref is not None): + gate_values = _padding_load(gate_values_ref, q_slice) + o *= gate_values + + o *= _diag(1./l) + + if (L_ref is not None): + L = m + _log(l) + pl.store(L_ref, (q_slice,), L, mask=q_bounds_mask) + # Write output to dram. + o = o.astype(o_ref.dtype) + _padding_store(o_ref, q_slice, o, q_bounds_mask) + + @classmethod + def flash_kernel( + cls, + q, + k, + v, + gate_values: jnp.ndarray | None = None, + mask: jnp.ndarray | None = None, + nonbatched_bias: jnp.ndarray | None = None, + return_residual: bool = False, + block_q: int = 32, + block_k: int = 32, + num_warps: int | None = 2, + num_stages: int = 2, + grid: tuple[int, ...] 
| None = None, + interpret: bool = False, + debug: bool = False, + ): + mask = mask.astype(jnp.bool_) + batch_size, q_len, num_heads, qk_head_dim = q.shape + _, kv_len, _, _ = k.shape + _, _, _, v_head_dim = v.shape + + block_q = min(block_q, q_len) + block_k = min(block_k, kv_len) + + if grid is None: + grid = (pl.cdiv(q_len, block_q), batch_size, num_heads) + + kernel = functools.partial(cls.flash_pallas_kernel, + block_q=block_q, + block_k=block_k) + out_shape = ( + jax.ShapeDtypeStruct(shape=(batch_size, q_len, num_heads, v_head_dim), dtype=q.dtype), # out + jax.ShapeDtypeStruct(shape=(batch_size, num_heads, q_len), dtype=jnp.float32) if return_residual else None, # L + ) + in_specs = ( + # j,k in grid parallelise over batch and head + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, q_len, None, qk_head_dim) # bqh(c_qk) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, kv_len, None, qk_head_dim) # bkh(c_qk) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, kv_len, None, v_head_dim) # bkh(c_v) + ), + ) + in_specs+= (None,) if (gate_values is None) else ( + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, q_len, None, v_head_dim)), # bqh(c_v) + ) + in_specs+= (None,) if (mask is None) else ( + pl.BlockSpec(lambda _, j, k: (j, k if mask.shape[1]!=1 else 0, 0, 0), (None, None,)+mask.shape[2:4]), # bhqk + ) + in_specs+= (None,) if (nonbatched_bias is None) else ( + pl.BlockSpec(lambda _, j, k: (k, 0, 0), (None, q_len, kv_len)), # hqk + ) + + out, L = pl.pallas_call( + kernel, + grid=grid, + in_specs=in_specs, + out_specs=( + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, q_len, None, v_head_dim)), # bqh(c_v) + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, q_len)) if return_residual else None, # bhq + ), + compiler_params=dict( + triton=dict(num_warps=num_warps, num_stages=num_stages) + ), + out_shape=out_shape, + debug=debug, + interpret=interpret, + name="flash_attention", + )(q, k, v, gate_values, mask, nonbatched_bias) + return out + class GlobalAttention(hk.Module): """Global attention. 
@@ -787,7 +982,6 @@ def __call__(self, dtype=msa_act.dtype, init=hk.initializers.RandomNormal(stddev=init_factor)) nonbatched_bias = jnp.einsum('qkc,ch->hqk', pair_act, weights) - attn_mod = Attention( c, self.global_config, msa_act.shape[-1]) msa_act = mapping.inference_subbatch( From d4fdf816c3f0c37f46ac9a63b5794fee2398b0be Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Sat, 20 Apr 2024 14:40:53 +0100 Subject: [PATCH 2/9] feat: allow bfloat16 flash attention --- alphafold/model/config.py | 12 +++++++++--- alphafold/model/modules.py | 3 +-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/alphafold/model/config.py b/alphafold/model/config.py index 53826c785..ab8875a6e 100644 --- a/alphafold/model/config.py +++ b/alphafold/model/config.py @@ -386,10 +386,10 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'eval_dropout': False, 'use_flash_attention': False, 'flash': { - 'num_warps': 2, - 'block_q': 64, + 'num_warps': 2, + 'block_q': 64, 'block_k': 32 - } + } }, 'heads': { 'distogram': { @@ -624,6 +624,12 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'use_remat': False, 'zero_init': True, 'eval_dropout': False, + 'use_flash_attention': False, + 'flash': { + 'num_warps': 2, + 'block_q': 64, + 'block_k': 32 + } }, 'heads': { 'distogram': { diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index 1659c5bfc..bea729356 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -629,7 +629,6 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): # kernel q_len, k_len = q.shape[1], k.shape[1] kernel = functools.partial(self.flash_kernel, **self.global_config.flash) if ( - ('use_flash_attention' in self.global_config) and self.global_config.use_flash_attention and q_len >= self.global_config.flash.block_q and k_len >= self.global_config.flash.block_k) else self.reference_kernel @@ -676,7 +675,7 @@ def flash_pallas_kernel( _rowmax = lambda x: x.max(axis=-1) _rowsum = lambda x: x.sum(axis=-1) _diag = lambda x: x[:, None] - _dot = pl.dot + _dot = lambda a, b: pl.dot(a.astype(b.dtype), b) # matches dtype of second element, in dot(p,v) this pushes to use bfloat16 _generate_1d_bounds_mask = lambda ref, start, block_size: lax.iota(jnp.int32, block_size)+start*block_size < ref.shape[0] # When head_dim < 16 pl.dot is not supported so we pad up _head_dim_mask = lambda ref: jax.lax.iota(jnp.int32, 16) Date: Sun, 21 Apr 2024 10:13:14 +0100 Subject: [PATCH 3/9] fix: support batch dim broadcast mask TemplateEmbedding uses attention with batch dim broadcast which wasn't supported `mask = template_mask[None, None, None,:]` --- alphafold/model/modules.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index bea729356..bc794528e 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -560,7 +560,7 @@ def __call__(self, q_data, m_data, mask, nonbatched_bias=None): q_data: A tensor of queries, shape [batch_size, N_queries, q_channels]. m_data: A tensor of memories from which the keys and values are projected, shape [batch_size, N_keys, m_channels]. - mask: A mask for the attention, shape [batch_size, N_heads or 1, N_queries or 1, N_keys]. + mask: A mask for the attention, shape [batch_size or 1, N_heads or 1, N_queries or 1, N_keys]. nonbatched_bias: Shared bias, shape [N_heads, N_queries, N_keys]. 
Returns: @@ -803,8 +803,13 @@ def flash_kernel( in_specs+= (None,) if (gate_values is None) else ( pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, q_len, None, v_head_dim)), # bqh(c_v) ) + # mask might be (b or 1, h or 1, q or 1, k) in shape, we deal with b and/or h broadcast here. + # we deal with q broadcast in the kernel in_specs+= (None,) if (mask is None) else ( - pl.BlockSpec(lambda _, j, k: (j, k if mask.shape[1]!=1 else 0, 0, 0), (None, None,)+mask.shape[2:4]), # bhqk + pl.BlockSpec(lambda _, j, k: ( + j if mask.shape[0]!=1 else 0, + k if mask.shape[1]!=1 else 0, + 0, 0), (None, None,)+mask.shape[2:4]), # bhqk ) in_specs+= (None,) if (nonbatched_bias is None) else ( pl.BlockSpec(lambda _, j, k: (k, 0, 0), (None, q_len, kv_len)), # hqk From ef77bf60e217c6244e3e8bc179ab4981b92221ff Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Sun, 21 Apr 2024 22:19:25 +0100 Subject: [PATCH 4/9] fix: add null callback, lowers to nan avoiding MHLO --- alphafold/model/modules.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index bc794528e..3ce20b78a 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -831,6 +831,9 @@ def flash_kernel( interpret=interpret, name="flash_attention", )(q, k, v, gate_values, mask, nonbatched_bias) + # NOTE: I found nans on an occassion with model_2_multimer_v3. This solved it. + # Leaving this callback in as should do no harm and later work out the change in MHLO + jax.debug.callback(lambda: None) return out class GlobalAttention(hk.Module): From d4516d83aaf65aee2e2c90ca85b86acacd464c0f Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Mon, 22 Apr 2024 08:40:33 +0100 Subject: [PATCH 5/9] fix: index guard all loads Removes any OOB indexing. Previously I've allowed out-of-bounds loads and fixed them by masks in qk. I've seen nan's appear which disappear with minorly varying MHLO. This commit removes any OOB indexing. --- alphafold/model/modules.py | 54 ++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index 3ce20b78a..f5bcda9ba 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -662,8 +662,8 @@ def flash_pallas_kernel( mask_ref: jax.Array | None, logit_bias_ref: jax.Array | None, # Output arrays - o_ref: jax.Array, - L_ref: jax.Array | None, + o_ref: jax.Array, + L_ref: jax.Array | None, # Options block_q: int, block_k: int, @@ -676,15 +676,25 @@ def flash_pallas_kernel( _rowsum = lambda x: x.sum(axis=-1) _diag = lambda x: x[:, None] _dot = lambda a, b: pl.dot(a.astype(b.dtype), b) # matches dtype of second element, in dot(p,v) this pushes to use bfloat16 - _generate_1d_bounds_mask = lambda ref, start, block_size: lax.iota(jnp.int32, block_size)+start*block_size < ref.shape[0] + _generate_1d_bounds_mask = lambda ref, start, block_size: lax.iota(jnp.int32, block_size)+start*block_size < ref.shape[0] if (ref.shape[0]%block_size!=0) else None # When head_dim < 16 pl.dot is not supported so we pad up _head_dim_mask = lambda ref: jax.lax.iota(jnp.int32, 16)=16) else \ - pl.load(ref, (slice, pl.dslice(0,16)), mask=_head_dim_mask(ref)[None,:], other=0.) 
- def _padding_store(ref, slice, val, dim0_mask): - pl.store(ref, (slice, pl.dslice(None)), val, mask=dim0_mask[:,None]) if (ref.shape[-1]>=16) else \ - pl.store(ref, (slice, pl.dslice(0,16)), val, mask=(dim0_mask[:,None] & _head_dim_mask(ref)[None,:])) + def _combine_masks(*masks: jax.Array | None): + # Combines a set of masks matching the dimension of the tensor to store. None for an all-True array + dims = set(range(len(masks))) + masks = [jnp.expand_dims(mask, tuple(dims-set({i}))) for i, mask in enumerate(masks) if mask is not None] + if (len(masks)==0): return None; # all None case - don't use a mask + mask = functools.reduce(lambda x, y: x & y, masks) + return mask + def _load_mask_kwargs(masks: tuple[jax.Array | None, ...], other=0.): + mask = _combine_masks(*masks) + return {'mask':mask, 'other':other} if (mask is not None) else {} + def _padding_load(ref, slice, dim0_mask=None): + return pl.load(ref, (slice, pl.dslice(None)), **_load_mask_kwargs((dim0_mask, None))) if (ref.shape[-1]>=16) else \ + pl.load(ref, (slice, pl.dslice(0,16)), **_load_mask_kwargs((dim0_mask, _head_dim_mask(ref)))) + def _padding_store(ref, slice, val, dim0_mask=None): + pl.store(ref, (slice, pl.dslice(None)), val, mask=_combine_masks(dim0_mask, None)) if (ref.shape[-1]>=16) else \ + pl.store(ref, (slice, pl.dslice(0,16)), val, mask=_combine_masks(dim0_mask, _head_dim_mask(ref))) # m and l are updated during the kv loop. # o is the buffer where we accumulate the output on sram. @@ -695,34 +705,35 @@ def _padding_store(ref, slice, val, dim0_mask): # Grid loops over q start_q = pl.program_id(0) q_slice = pl.dslice(start_q * block_q, block_q) - q = _padding_load(q_ref, q_slice) q_bounds_mask = _generate_1d_bounds_mask(q_ref, start_q, block_q) + q = _padding_load(q_ref, q_slice, q_bounds_mask) # Here we only loop over blocks of kv to process entire seq_len, the loop over # blocks of q is carried out by the grid. 
def body(start_k, carry): o_prev, m_prev, l_prev = carry k_slice = pl.dslice(start_k * block_k, block_k) + kv_bounds_mask = _generate_1d_bounds_mask(k_ref, start_k, block_k) - k = _padding_load(k_ref, k_slice) - v = _padding_load(v_ref, k_slice) + k = _padding_load(k_ref, k_slice, kv_bounds_mask) + v = _padding_load(v_ref, k_slice, kv_bounds_mask) qk = pl.dot(q, k.T) # (block_q, block_k) if (logit_bias_ref is not None): - logit_bias = pl.load(logit_bias_ref, (q_slice, k_slice)) + logit_bias = pl.load(logit_bias_ref, (q_slice, k_slice), + **_load_mask_kwargs((q_bounds_mask, kv_bounds_mask))) qk += logit_bias if (mask_ref is not None): kv_only_mask = (mask_ref.shape[0]==1) - mask = pl.load(mask_ref, ( - pl.dslice(0,1) if kv_only_mask else q_slice, - k_slice)) + mask = pl.load(mask_ref, (pl.dslice(0,1) if kv_only_mask else q_slice, k_slice), + **_load_mask_kwargs((None if kv_only_mask else q_bounds_mask, kv_bounds_mask))) qk = jnp.where(mask, qk, _SOFTMAX_MASK) # boundary checks - kv_bounds_mask = _generate_1d_bounds_mask(k_ref, start_k, block_k) - bounds_mask = (q_bounds_mask[:,None] & kv_bounds_mask[None,:]) - qk = jnp.where(bounds_mask, qk, _SOFTMAX_MASK) + bounds_mask = _combine_masks(q_bounds_mask, kv_bounds_mask) + if (bounds_mask is not None): + qk = jnp.where(bounds_mask, qk, _SOFTMAX_MASK) # Mapping of indexing to FlashAttention2 papers notation: # x_{i}^{j-1} = x_prev @@ -740,7 +751,7 @@ def body(start_k, carry): o, m, l = lax.fori_loop(0, n_blocks, body, (o_init, m_init, l_init)) if (gate_values_ref is not None): - gate_values = _padding_load(gate_values_ref, q_slice) + gate_values = _padding_load(gate_values_ref, q_slice, q_bounds_mask) o *= gate_values o *= _diag(1./l) @@ -831,9 +842,6 @@ def flash_kernel( interpret=interpret, name="flash_attention", )(q, k, v, gate_values, mask, nonbatched_bias) - # NOTE: I found nans on an occassion with model_2_multimer_v3. This solved it. 
- # Leaving this callback in as should do no harm and later work out the change in MHLO - jax.debug.callback(lambda: None) return out class GlobalAttention(hk.Module): From 8a8c1c9f750af0f0b8796a37cc3c247c6204640d Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Mon, 22 Apr 2024 08:40:51 +0100 Subject: [PATCH 6/9] perf: use exp2 in place of exp for speed --- alphafold/model/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index f5bcda9ba..4351e48c9 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -669,8 +669,8 @@ def flash_pallas_kernel( block_k: int, ): # convenience functions to match syntax of FlashAttention2 forward pass (Algorithm 1) - _log = jnp.log - _exp = jnp.exp + _log = jnp.log2 + _exp = jnp.exp2 _maximum = jnp.maximum _rowmax = lambda x: x.max(axis=-1) _rowsum = lambda x: x.sum(axis=-1) From e1d2eb3e80cba8b8e96c65a84035ed8974001560 Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Mon, 22 Apr 2024 09:08:14 +0100 Subject: [PATCH 7/9] fix: revert exp2 change, output variations too large --- alphafold/model/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index 4351e48c9..f5bcda9ba 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -669,8 +669,8 @@ def flash_pallas_kernel( block_k: int, ): # convenience functions to match syntax of FlashAttention2 forward pass (Algorithm 1) - _log = jnp.log2 - _exp = jnp.exp2 + _log = jnp.log + _exp = jnp.exp _maximum = jnp.maximum _rowmax = lambda x: x.max(axis=-1) _rowsum = lambda x: x.sum(axis=-1) From 80af62325a96175bba8ac6a64c4d1feea10e9e24 Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Mon, 22 Apr 2024 18:21:03 +0100 Subject: [PATCH 8/9] style: cleanup --- alphafold/model/modules.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index f5bcda9ba..ae1a00fd4 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -739,8 +739,7 @@ def body(start_k, carry): # x_{i}^{j-1} = x_prev # x_{i}^{j} = x s = qk - rowmax_s = _rowmax(s) - m = _maximum(m_prev, rowmax_s) + m = _maximum(m_prev, _rowmax(s)) p = _exp(s-_diag(m)) # _diag not present in paper, but is needed l = _exp(m_prev-m)*l_prev + _rowsum(p) o = _diag(_exp(m_prev-m)) * o_prev + _dot(p,v) From 1787c30b2e372377bd4ed330ace872e6cc2ced7b Mon Sep 17 00:00:00 2001 From: Oliver Dutton Date: Tue, 23 Apr 2024 20:21:07 +0000 Subject: [PATCH 9/9] fix: mask=None case bug --- alphafold/model/modules.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index ae1a00fd4..aaf8fc6e9 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -780,7 +780,8 @@ def flash_kernel( interpret: bool = False, debug: bool = False, ): - mask = mask.astype(jnp.bool_) + if (mask is not None): + mask = mask.astype(jnp.bool_) batch_size, q_len, num_heads, qk_head_dim = q.shape _, kv_len, _, _ = k.shape _, _, _, v_head_dim = v.shape
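Usage note: the flash path is off by default, and the Attention module falls back to the reference einsum kernel whenever `use_flash_attention` is False or the query/key lengths are shorter than the configured block sizes. A minimal sketch (not part of these patches) of enabling it through the `global_config` keys added in config.py follows; the model name is illustrative, and the block sizes simply echo the defaults introduced above.

    from alphafold.model import config

    cfg = config.model_config('model_1_multimer_v3')  # illustrative model name
    cfg.model.global_config.use_flash_attention = True
    # Tile sizes and warp count consumed by the Pallas kernel; values mirror the
    # defaults added in these patches.
    cfg.model.global_config.flash.block_q = 64
    cfg.model.global_config.flash.block_k = 32
    cfg.model.global_config.flash.num_warps = 2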