From df73d2ce9d6bb25ee3182e7aa8d6d41f4e1382b3 Mon Sep 17 00:00:00 2001
From: ZincCat
Date: Tue, 22 Oct 2024 16:42:48 -0400
Subject: [PATCH] fix shape check

---
 README.md                       |   2 +-
 examples/benchmark.py           |  19 +++--
 flaxattention/core/attention.py |   8 +-
 flaxattention/utils.py          | 139 ++++++++++++++++++++++++++++++++
 4 files changed, 155 insertions(+), 13 deletions(-)
 create mode 100644 flaxattention/utils.py

diff --git a/README.md b/README.md
index c20f43f..35f0074 100644
--- a/README.md
+++ b/README.md
@@ -100,7 +100,7 @@ Float16:
 |------------------------|----------------------------|----------------------------|
 | FlaxAttention (Pure JAX) | 0.5692746052518487 | 0.8823547409847379 |
 | FlaxAttention (Pallas) | **0.13677988620474935** | **0.5575501238927245** |
-| Jax Attention (no score_mod) | 0.2551286369562149 | 0.04072062578052282 |
+| Jax Attention (no score_mod) | 1.6788566000759602 | 1.0905949068255723 |
 | FlexAttention (Torch)| **0.11708855209872127** | **0.5104729640297592** |
 
 We can see that the forward performance is about 20% slower than the original implementation, while backward about 8% slower. There are still some optimizations to be done.
diff --git a/examples/benchmark.py b/examples/benchmark.py
index 36ded28..0afaf77 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -155,20 +155,23 @@ def fn(query, key, value):
 
 
 # try jax attention
+query_transposed = jnp.moveaxis(query, 1, 2)
+key_transposed = jnp.moveaxis(key, 1, 2)
+value_transposed = jnp.moveaxis(value, 1, 2)
 # warm up
 output = dot_product_attention(
-    query,
-    key,
-    value,
+    query_transposed,
+    key_transposed,
+    value_transposed,
 )
 output.block_until_ready()
 
 start = timer()
 for _ in range(100):
     output = dot_product_attention(
-        query,
-        key,
-        value,
+        query_transposed,
+        key_transposed,
+        value_transposed,
     )
 output.block_until_ready()
 end = timer()
@@ -184,14 +187,14 @@ def fn1(query, key, value):
 grad_fn1 = jax.jit(grad_fn1)
 
 # warm up
-grad = grad_fn1(query, key, value)
+grad = grad_fn1(query_transposed, key_transposed, value_transposed)
 grad.block_until_ready()
 
 # print(grad[0, 0, 0])
 
 start = timer()
 for _ in range(100):
-    grad = grad_fn1(query, key, value)
+    grad = grad_fn1(query_transposed, key_transposed, value_transposed)
 grad.block_until_ready()
 end = timer()
 print("Jax dot product attention gradient time taken (no score_mod):", end - start)
diff --git a/flaxattention/core/attention.py b/flaxattention/core/attention.py
index bf826d6..f31ce88 100644
--- a/flaxattention/core/attention.py
+++ b/flaxattention/core/attention.py
@@ -252,15 +252,15 @@ def flax_attention_pallas(
     _validate_embed_dim(query, key, value)
     if query.ndim != 4 or key.ndim != 4 or value.ndim != 4:
         raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
-    if (not enable_gqa) and query.shape[-3] != key.shape[-3]:
+    if (not enable_gqa) and query.shape[-2] != key.shape[-2]:
         raise ValueError(
             f"Expect query and key/value to have the same number of heads "
-            f"but got Hq={query.shape[-3]} and Hkv={key.shape[-3]}. "
+            f"but got Hq={query.shape[-2]} and Hkv={key.shape[-2]}. "
             f"Try setting enable_gqa=True for GQA."
         )
     if enable_gqa:
-        Hq = query.shape[1]
-        Hkv = key.shape[1]
+        Hq = query.shape[2]
+        Hkv = key.shape[2]
         if Hq % Hkv != 0:
             raise ValueError(
                 f"Expect number of query heads to be a multiple of kv heads for GQA "
diff --git a/flaxattention/utils.py b/flaxattention/utils.py
new file mode 100644
index 0000000..8da21a1
--- /dev/null
+++ b/flaxattention/utils.py
@@ -0,0 +1,139 @@
+import jax
+from jax import numpy as jnp
+from typing import Optional
+import matplotlib.pyplot as plt
+from pathlib import Path
+import numpy as np
+import math
+from .core.common import (
+    _score_mod_signature,
+    _mask_mod_signature,
+    _vmap_for_bhqkv,
+    _ModificationType,
+)
+
+Array = jax.Array
+
+def create_score_mod(
+    query: Array,
+    key: Array,
+    score_mod: Optional[_score_mod_signature],
+    mask_mod: Optional[_mask_mod_signature],
+    scale: Optional[float] = None,
+    batch_idx: int = 0,
+    head_idx: int = 0,
+) -> Array:
+    B = 1
+    H = 1
+    M = query.shape[0]
+    N = key.shape[0]
+
+    b = jnp.arange(0, B) + batch_idx
+    h = jnp.arange(0, H) + head_idx
+    m = jnp.arange(0, M)
+    n = jnp.arange(0, N)
+
+    scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
+    type = _ModificationType.SCORE_MOD if score_mod is not None else _ModificationType.MASK_MOD
+    mod_fn = score_mod if type == _ModificationType.SCORE_MOD else mask_mod
+    prefix = (0,) if type == _ModificationType.SCORE_MOD else ()
+    mod = _vmap_for_bhqkv(mod_fn, prefix=prefix)
+    scores = query @ jnp.moveaxis(key, -1, -2)
+    scores *= scale_factor
+    scores = scores.reshape(1, 1, M, N)
+    if type == _ModificationType.SCORE_MOD:
+        out = mod(scores, b, h, m, n)
+    else:
+        out = mod(b, h, m, n)
+
+    return out
+
+
+def _name_to_title(name: str) -> str:
+    title = name.replace("_", " ")
+    title = " ".join(word.capitalize() for word in title.split())
+    return title
+
+
+def visualize_attention_scores(
+    query: Array,
+    key: Array,
+    score_mod: Optional[_score_mod_signature] = None,
+    mask_mod: Optional[_mask_mod_signature] = None,
+    device: str = "cuda",
+    name: str = "attention_scores",
+    path: Optional[Path] = None,
+    batch_idx: int = 0,
+    head_idx: int = 0,
+    scale: Optional[float] = None,
+):
+    """
+    Generate and save a visualization of attention scores.
+
+    Args:
+        query (Tensor): Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim).
+        key (Tensor): Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim).
+        score_mod (Optional[Callable]): If this is set this will take precedence over the mask_mod.
+        mask_mod (Optional[Callable]): The mask_mod function used to create block_mask
+        device (str): Device to run computations on (default: "cuda").
+        name (str): Base name for the file and title (default: 'attention_scores').
+        path (Path): Path to save the visualization. If None, will be saved to the current working directory.
+        batch_idx (int): Index of the batch to visualize (default: 0).
+        head_idx (int): Index of the head to visualize (default: 0).
+        scale (float): Scale factor to apply to the attention scores. If None, will be set to 1 / sqrt(head_dim).
+
+    Returns:
+        None
+    """
+    assert (
+        score_mod is not None or mask_mod is not None
+    ), "Must provide either score_mod or mask_mod"
+    query = query[batch_idx, head_idx, :, :]
+    key = key[batch_idx, head_idx, :, :]
+    scores_viz = create_score_mod(
+        query,
+        key,
+        score_mod=score_mod,
+        mask_mod=mask_mod,
+        scale=scale,
+        batch_idx=batch_idx,
+        head_idx=head_idx,
+    )
+
+    suffix_title = f"Batch {batch_idx}, Head {head_idx}" if batch_idx != 0 or head_idx != 0 else ""
+
+    fig, ax = plt.subplots(figsize=(12, 10))
+    color = "viridis" if score_mod is not None else "cividis"
+    im = ax.imshow(scores_viz[0, 0, :, :], aspect="auto", cmap=color)
+    fig.colorbar(im)
+
+    title = _name_to_title(name)
+    file_path = Path(name).with_suffix(".png") if path is None else path.with_suffix(".png")
+    ax.set_title(f"{title}\n{suffix_title}", fontsize=20)
+
+    ax.set_xlabel("Key Tokens", fontsize=18)
+    ax.set_ylabel("Query Tokens", fontsize=18)
+
+    # Move y-axis ticks and labels to the top
+    ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False)
+
+    # Add tick labels if the number of tokens is manageable
+    num_query_tokens, num_kv_tokens = scores_viz.shape[-2:]
+    if num_query_tokens <= 32 and num_kv_tokens <= 32:
+        ax.set_xticks(range(num_kv_tokens))
+        rotation = 45 if num_kv_tokens > 12 else 0
+        ax.set_xticklabels(
+            [f"KV{i}" for i in range(num_kv_tokens)], fontsize=16, rotation=rotation
+        )
+        ax.set_yticks(range(num_query_tokens))
+        ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)], fontsize=16)
+        # Align grid with pixel boundaries
+        ax.set_xticks(np.arange(-0.5, num_kv_tokens, 1), minor=True)
+        ax.set_yticks(np.arange(-0.5, num_query_tokens, 1), minor=True)
+        ax.grid(which="minor", color="black", linestyle="-", linewidth=2)
+
+    plt.tight_layout()
+    plt.savefig(file_path, dpi=300, bbox_inches="tight")
+    plt.close(fig)  # Close the figure to free up memory
+
+    print(f"Visualization saved as {file_path}")
\ No newline at end of file
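
Note on the examples/benchmark.py change above: assuming `dot_product_attention` in that script is `jax.nn.dot_product_attention`, it takes inputs laid out as (batch, seq_len, num_heads, head_dim), while the benchmark tensors use the flex-attention-style (batch, num_heads, seq_len, head_dim) layout, which is why the patch now swaps the two middle axes with `jnp.moveaxis` before the call. A minimal sketch of that conversion, with made-up shapes that are not taken from the benchmark:

```python
import jax
from jax import numpy as jnp

# Hypothetical shapes for illustration: (batch, num_heads, seq_len, head_dim).
B, H, S, D = 2, 8, 256, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (B, H, S, D), dtype=jnp.float16)
key = jax.random.normal(kk, (B, H, S, D), dtype=jnp.float16)
value = jax.random.normal(kv, (B, H, S, D), dtype=jnp.float16)

# (B, H, S, D) -> (B, S, H, D): the layout jax.nn.dot_product_attention expects.
query_t = jnp.moveaxis(query, 1, 2)
key_t = jnp.moveaxis(key, 1, 2)
value_t = jnp.moveaxis(value, 1, 2)

out = jax.nn.dot_product_attention(query_t, key_t, value_t)  # (B, S, H, D)
out = jnp.moveaxis(out, 1, 2)  # back to (B, H, S, D) to compare with flex-attention-style outputs
print(out.shape)  # (2, 8, 256, 64)
```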
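For the new flaxattention/utils.py helper, a minimal usage sketch based only on the docstring above; the `relative_bias` score_mod and the shapes are hypothetical examples, not part of the patch:

```python
import jax
from jax import numpy as jnp
from flaxattention.utils import visualize_attention_scores

# Hypothetical score_mod for illustration: a simple relative-position bias.
def relative_bias(score, batch, head, q_idx, kv_idx):
    return score + (kv_idx - q_idx)

# Shapes follow the docstring: (batch_size, num_heads, seq_len, head_dim);
# seq_len <= 32 so the per-token tick labels are drawn.
kq, kk = jax.random.split(jax.random.PRNGKey(0))
query = jax.random.normal(kq, (1, 1, 16, 64))
key = jax.random.normal(kk, (1, 1, 16, 64))

# Saves relative_bias.png in the current working directory.
visualize_attention_scores(query, key, score_mod=relative_bias, name="relative_bias")
```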