diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 3c4c5d76e..aa9edc5f6 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,3 +1,5 @@ +# If .venv is already setup with python3.8, it will use python3.8. To use 3.11 remove it first. + # Use Nvidia Ubuntu 20 base (includes CUDA if a supported GPU is present) # https://hub.docker.com/r/nvidia/cuda FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04@sha256:55211df43bf393d3393559d5ab53283d4ebc3943d802b04546a24f3345825bd9 @@ -17,18 +19,26 @@ RUN groupadd --gid $USER_GID $USERNAME \ && chmod 0440 /etc/sudoers.d/$USERNAME # Install dependencies -RUN sudo apt-get update && \ +RUN apt-get update && \ + DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \ + software-properties-common && \ + add-apt-repository -y ppa:deadsnakes/ppa && \ + apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \ build-essential \ - python3.9 \ - python3.9-dev \ - python3.9-distutils \ - python3.9-venv \ + python3.11 \ + python3.11-dev \ + python3.11-distutils \ + python3.11-venv \ curl \ - git + git && \ + # Update python3 default to point to python3.11 + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 && \ + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2 && \ + update-alternatives --set python3 /usr/bin/python3.11 # User the new user USER $USERNAME # Install poetry -RUN curl -sSL https://install.python-poetry.org | python3 - +RUN curl -sSL https://install.python-poetry.org | python3.11 - diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index fb686122d..552798b1c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -69,6 +69,8 @@ jobs: poetry install --with dev - name: Unit Test run: make unit-test + env: + HF_TOKEN: ${{ vars.HF_TOKEN }} - name: Acceptance Test run: make acceptance-test - name: Build check @@ -106,6 +108,8 @@ jobs: run: poetry run mypy . - name: Test Suite with Coverage Report run: make coverage-report-test + env: + HF_TOKEN: ${{ vars.HF_TOKEN }} - name: Build check run: poetry build - name: Upload Coverage Report Artifact @@ -195,7 +199,7 @@ jobs: - name: Build Docs run: poetry run build-docs env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_TOKEN: ${{ vars.HF_TOKEN }} - name: Upload Docs Artifact uses: actions/upload-artifact@v3 with: diff --git a/.gitignore b/.gitignore index 61589404d..978e887aa 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ docs/build .pylintrc docs/source/generated **.orig +.venv diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index 62b5814d4..2f28a37a1 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -16,9 +16,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"load_ext autoreload\")\n", - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_57027/2944939757.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", " ipython.magic(\"autoreload 2\")\n" ] } @@ -51,28 +51,28 @@ " %pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", " %pip install torch\n", " %pip install tiktoken\n", - " %pip install transformer_lens\n", + " # %pip install transformer_lens\n", " %pip install transformers_stream_generator\n", " # !huggingface-cli login --token NEEL'S TOKEN" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "TransformerLens currently supports 190 models out of the box.\n" + "TransformerLens currently supports 206 models out of the box.\n" ] } ], "source": [ "import torch\n", "\n", - "from transformer_lens import HookedTransformer, HookedEncoderDecoder, loading\n", + "from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, loading\n", "from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer\n", "from typing import List\n", "import gc\n", @@ -144,11 +144,11 @@ " inputs = tokenizer(prompt, return_tensors=\"pt\")\n", " input_ids = inputs[\"input_ids\"]\n", " attention_mask = inputs[\"attention_mask\"]\n", - " decoder_input_ids = torch.tensor([[model.cfg.decoder_start_token_id]]).to(input_ids.device)\n", + " decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)\n", "\n", "\n", " while True:\n", - " logits = model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n", + " logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n", " # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n", "\n", " token_idx = torch.argmax(logits[0, -1, :]).item()\n", @@ -160,7 +160,29 @@ " # break if End-Of-Sequence token generated\n", " if token_idx == tokenizer.eos_token_id:\n", " break\n", - " print(tl_model.generate(\"Hello my name is\"))\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*\n", + "\n", + "def run_encoder_only_set(model_set: List[str], device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n", + " tl_model = HookedEncoder.from_pretrained(model, device=device)\n", + "\n", + " if GENERATE:\n", + " # Slightly adapted version of the BERT demo\n", + " prompt = \"The capital of France is [MASK].\"\n", + "\n", + " input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n", + "\n", + " logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n", + " prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n", + "\n", + " print(f\"Prompt: {prompt}\")\n", + " print(f'Prediction: \"{prediction}\"')\n", + "\n", " del tl_model\n", " gc.collect()\n", " if IN_COLAB:\n", @@ -169,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -177,7 +199,6 @@ "free_compatible = [\n", " \"ai-forever/mGPT\",\n", " \"ArthurConmy/redwood_attn_2l\",\n", - " \"bert-base-cased\",\n", " \"bigcode/santacoder\",\n", " \"bigscience/bloom-1b1\",\n", " \"bigscience/bloom-560m\",\n", @@ -256,6 +277,10 @@ " \"Qwen/Qwen2-0.5B-Instruct\",\n", " \"Qwen/Qwen2-1.5B\",\n", " \"Qwen/Qwen2-1.5B-Instruct\",\n", + " \"Qwen/Qwen2.5-0.5B\",\n", + " \"Qwen/Qwen2.5-0.5B-Instruct\",\n", + " \"Qwen/Qwen2.5-1.5B\",\n", + " \"Qwen/Qwen2.5-1.5B-Instruct\",\n", " \"roneneldan/TinyStories-1Layer-21M\",\n", " \"roneneldan/TinyStories-1M\",\n", " \"roneneldan/TinyStories-28M\",\n", @@ -290,7 +315,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -340,6 +365,10 @@ " \"Qwen/Qwen1.5-7B-Chat\",\n", " \"Qwen/Qwen2-7B\",\n", " \"Qwen/Qwen2-7B-Instruct\",\n", + " \"Qwen/Qwen2.5-3B\",\n", + " \"Qwen/Qwen2.5-3B-Instruct\",\n", + " \"Qwen/Qwen2.5-7B\",\n", + " \"Qwen/Qwen2.5-7B-Instruct\",\n", " \"stabilityai/stablelm-base-alpha-3b\",\n", " \"stabilityai/stablelm-base-alpha-7b\",\n", " \"stabilityai/stablelm-tuned-alpha-3b\",\n", @@ -354,7 +383,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -374,6 +403,8 @@ " \"Qwen/Qwen-14B-Chat\",\n", " \"Qwen/Qwen1.5-14B\",\n", " \"Qwen/Qwen1.5-14B-Chat\",\n", + " \"Qwen/Qwen2.5-14B\",\n", + " \"Qwen/Qwen2.5-14B-Instruct\",\n", "]\n", "\n", "if IN_COLAB:\n", @@ -384,7 +415,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -398,10 +429,16 @@ " \"meta-llama/Llama-2-70b-chat-hf\",\n", " \"meta-llama/Llama-3.1-70B\",\n", " \"meta-llama/Llama-3.1-70B-Instruct\",\n", + " \"meta-llama/Llama-3.3-70B-Instruct\",\n", " \"meta-llama/Meta-Llama-3-70B\",\n", " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n", " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n", " \"mistralai/Mixtral-8x7B-v0.1\",\n", + " \"Qwen/Qwen2.5-32B\",\n", + " \"Qwen/Qwen2.5-32B-Instruct\",\n", + " \"Qwen/Qwen2.5-72B\",\n", + " \"Qwen/Qwen2.5-72B-Instruct\",\n", + " \"Qwen/QwQ-32B-Preview\",\n", "]\n", "\n", "mark_models_as_tested(incompatible_models)" @@ -409,7 +446,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -431,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -449,7 +486,22 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# This model works on the free version of Colab\n", + "encoder_only_models = [\"bert-base-cased\"]\n", + "\n", + "if IN_COLAB:\n", + " run_encoder_only_set(encoder_only_models)\n", + "\n", + "mark_models_as_tested(encoder_only_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -460,7 +512,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -499,5 +551,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index c0fed32d9..41853de67 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -429,6 +429,26 @@ "cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we only wanted the layer 0 attention patterns, but we are storing the internal activations from all locations in the model. It's convenient to have access to all activations, but this can be prohibitively expensive for memory use with larger models, batch sizes, or sequence lengths. In addition, we don't need to do the full forward pass through the model to collect layer 0 attention patterns. The following cell will collect only the layer 0 attention patterns and stop the forward pass at layer 1, requiring far less memory and compute." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "attn_hook_name = \"blocks.0.attn.hook_pattern\"\n", + "attn_layer = 0\n", + "_, gpt2_attn_cache = model.run_with_cache(gpt2_tokens, remove_batch_dim=True, stop_at_layer=attn_layer + 1, names_filter=[attn_hook_name])\n", + "gpt2_attn = gpt2_attn_cache[attn_hook_name]\n", + "assert torch.equal(gpt2_attn, attention_pattern)" + ] + }, { "attachments": {}, "cell_type": "markdown", diff --git a/tests/integration/test_hooks.py b/tests/integration/test_hooks.py index 6a9880a67..29d5ff9ed 100644 --- a/tests/integration/test_hooks.py +++ b/tests/integration/test_hooks.py @@ -234,3 +234,10 @@ def set_to_randn(z, hook): # exactly when the zero hook is attached last XOR it is prepended assert torch.allclose(logits, model.unembed.b_U[None, :]) == logits_are_unembed_bias + + +def test_use_attn_in_with_gqa_raises_error(): + # Create model that uses GroupedQueryAttention + model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B") + with pytest.raises(AssertionError): + model.set_use_attn_in(True) diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py index c27ca6cbd..572529cec 100644 --- a/tests/unit/components/test_attention.py +++ b/tests/unit/components/test_attention.py @@ -1,3 +1,4 @@ +import einops import pytest import torch import torch.nn as nn @@ -5,6 +6,7 @@ from transformer_lens.components import Attention from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.utilities.attention import complex_attn_linear if is_bitsandbytes_available(): from bitsandbytes.nn.modules import Params4bit @@ -98,3 +100,31 @@ def test_attention_config_dict(): assert attn.cfg.load_in_4bit == False assert attn.cfg.dtype == torch.float32 assert attn.cfg.act_fn == "relu" + + +def test_remove_einsum_from_complex_attn_linear(): + batch = 64 + pos = 128 + head_index = 8 + d_model = 512 + d_head = 64 + input = torch.randn(batch, pos, head_index, d_model) + w = torch.randn(head_index, d_model, d_head) + b = torch.randn(head_index, d_head) + result_new = complex_attn_linear(input, w, b) + + # Check if new implementation without einsum produces correct shape + assert result_new.shape == (batch, pos, head_index, d_head) + + # Old implementation used einsum + result_old = ( + einops.einsum( + input, + w, + "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head", + ) + + b + ) + + # Check if the results are the same + assert torch.allclose(result_new, result_old, atol=1e-4) diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 6fa336c23..e2164e178 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -20,7 +20,6 @@ class first, including the examples, and then skimming the available methods. Yo import einops import numpy as np import torch -from fancy_einsum import einsum from jaxtyping import Float, Int from typing_extensions import Literal @@ -557,10 +556,8 @@ def logit_attrs( has_batch_dim=has_batch_dim, ) - logit_attrs = einsum( - "... d_model, ... d_model -> ...", scaled_residual_stack, logit_directions - ) - + # Element-wise multiplication and sum over the d_model dimension + logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1) return logit_attrs def decompose_resid( @@ -666,15 +663,22 @@ def compute_head_results( if "blocks.0.attn.hook_result" in self.cache_dict: logging.warning("Tried to compute head results when they were already cached") return - for l in range(self.model.cfg.n_layers): + for layer in range(self.model.cfg.n_layers): # Note that we haven't enabled set item on this object so we need to edit the underlying # cache_dict directly. - self.cache_dict[f"blocks.{l}.attn.hook_result"] = einsum( - "... head_index d_head, head_index d_head d_model -> ... head_index d_model", - self[("z", l, "attn")], - self.model.blocks[l].attn.W_O, + + # Add singleton dimension to match W_O's shape for broadcasting + z = einops.rearrange( + self[("z", layer, "attn")], + "... head_index d_head -> ... head_index d_head 1", ) + # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model]) + result = z * self.model.blocks[layer].attn.W_O + + # Sum over d_head to get the contribution of each head to the residual stream + self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2) + def stack_head_results( self, layer: int = -1, diff --git a/transformer_lens/SVDInterpreter.py b/transformer_lens/SVDInterpreter.py index cf0354d61..166fa36f3 100644 --- a/transformer_lens/SVDInterpreter.py +++ b/transformer_lens/SVDInterpreter.py @@ -6,7 +6,6 @@ from typing import Optional, Union -import fancy_einsum as einsum import torch from typeguard import typechecked from typing_extensions import Literal @@ -148,7 +147,7 @@ def _get_w_in_matrix(self, layer_index: int) -> torch.Tensor: if f"blocks.{layer_index}.ln2.w" in self.params: # If fold_ln == False ln_2 = self.params[f"blocks.{layer_index}.ln2.w"] - return einsum.einsum("out in, in -> out in", w_in, ln_2) + return w_in * ln_2 return w_in diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 9c4855091..009d2cfb8 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -297,17 +297,19 @@ def forward( ) ) else: + # Add singleton dimensions to make shapes compatible for broadcasting: w = einops.rearrange( self.W_O, - "head_index d_head d_model -> d_model head_index d_head", + "head_index d_head d_model -> 1 1 head_index d_head d_model", ) - result = self.hook_result( - einops.einsum( - z, - w, - "... head_index d_head, d_model head_index d_head -> ... head_index d_model", - ) - ) # [batch, pos, head_index, d_model] + z = einops.rearrange( + z, "batch pos head_index d_head -> batch pos head_index d_head 1" + ) + + # Multiply the z tensor by the W_O tensor, summing over the d_head dimension + unhooked_result = (z * w).sum(-2) + + result = self.hook_result(unhooked_result) # [batch, pos, head_index, d_model] out = ( einops.reduce(result, "batch position index model->batch position model", "sum") + self.b_O @@ -456,9 +458,16 @@ def apply_causal_mask( final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:] # [1, 1, pos, pos] if attention_mask is not None: # Apply a causal mask to the attention scores considering the padding - einsum_str = "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos" + + # Add singleton dimensions to the attention mask to match the shape of the final mask + attention_mask = einops.rearrange( + attention_mask, "batch offset_pos -> batch 1 1 offset_pos" + ) + final_mask = final_mask.to(attention_mask.device) - final_mask = einops.einsum(final_mask, attention_mask, einsum_str).bool() + + # Element-wise multiplication of the final mask and the attention mask and cast to boolean + final_mask = (final_mask * attention_mask).bool() # [batch, head, pos, offset_pos] attn_scores = attn_scores.to(final_mask.device) return torch.where(final_mask, attn_scores, self.IGNORE) @@ -690,7 +699,11 @@ def create_alibi_bias( n_heads, device ) - # The ALiBi bias is then m * slope_matrix - alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) + # Add singleton dimensions to make shapes compatible for broadcasting: + slope = einops.rearrange(slope, "query key -> 1 query key") + multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1") + + # Element-wise multiplication of the slope and multipliers + alibi_bias = multipliers * slope return alibi_bias diff --git a/transformer_lens/components/bert_mlm_head.py b/transformer_lens/components/bert_mlm_head.py index 89a86808c..162b51899 100644 --- a/transformer_lens/components/bert_mlm_head.py +++ b/transformer_lens/components/bert_mlm_head.py @@ -4,9 +4,9 @@ """ from typing import Dict, Union +import einops import torch import torch.nn as nn -from fancy_einsum import einsum from jaxtyping import Float from transformer_lens.components import LayerNorm @@ -27,14 +27,15 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): self.ln = LayerNorm(self.cfg) def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor: - resid = ( - einsum( - "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out", - resid, - self.W, - ) - + self.b - ) + # Add singleton dimension for broadcasting + resid = einops.rearrange(resid, "batch pos d_model_in -> batch pos 1 d_model_in") + + # Element-wise multiplication of W and resid + resid = resid * self.W + + # Sum over d_model_in dimension and add bias + resid = resid.sum(-1) + self.b + resid = self.act_fn(resid) resid = self.ln(resid) return resid diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index ec718810a..4bda8ccb8 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -439,7 +439,8 @@ def run_with_hooks( clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is False. *model_args: Positional arguments for the model. - **model_kwargs: Keyword arguments for the model. + **model_kwargs: Keyword arguments for the model's forward function. See your related + models forward pass for details as to what sort of arguments you can pass through. Note: If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks @@ -540,7 +541,8 @@ def run_with_cache( Defaults to False. pos_slice: The slice to apply to the cache output. Defaults to None, do nothing. - **model_kwargs: Keyword arguments for the model. + **model_kwargs: Keyword arguments for the model's forward function. See your related + models forward pass for details as to what sort of arguments you can pass through. Returns: tuple: A tuple containing the model output and a Cache object. diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index aa544786f..17d32e8c7 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -151,14 +151,15 @@ "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-70B", "meta-llama/Meta-Llama-3-70B-Instruct", - "meta-llama/Llama-3.2-1B", - "meta-llama/Llama-3.2-3B", - "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.1-70B", "meta-llama/Llama-3.1-8B", "meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.2-3B", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.3-70B-Instruct", "Baidicoot/Othello-GPT-Transformer-Lens", "bert-base-cased", "roneneldan/TinyStories-1M", @@ -212,6 +213,21 @@ "Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-7B", "Qwen/Qwen2-7B-Instruct", + "Qwen/Qwen2.5-0.5B", + "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen2.5-1.5B", + "Qwen/Qwen2.5-1.5B-Instruct", + "Qwen/Qwen2.5-3B", + "Qwen/Qwen2.5-3B-Instruct", + "Qwen/Qwen2.5-7B", + "Qwen/Qwen2.5-7B-Instruct", + "Qwen/Qwen2.5-14B", + "Qwen/Qwen2.5-14B-Instruct", + "Qwen/Qwen2.5-32B", + "Qwen/Qwen2.5-32B-Instruct", + "Qwen/Qwen2.5-72B", + "Qwen/Qwen2.5-72B-Instruct", + "Qwen/QwQ-32B-Preview", "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2", @@ -945,6 +961,30 @@ def convert_hf_model_config(model_name: str, **kwargs): "NTK_by_parts_high_freq_factor": 4.0, "NTK_by_parts_factor": 32.0, } + elif "Llama-3.3-70B" in official_model_name: + cfg_dict = { + "d_model": 8192, + "d_head": 128, + "n_heads": 64, + "d_mlp": 28672, + "n_layers": 80, + "n_ctx": 2048, # capped due to memory issues + "eps": 1e-5, + "d_vocab": 128256, + "act_fn": "silu", + "n_key_value_heads": 8, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, + "rotary_dim": 32, + "final_rms": True, + "gated_mlp": True, + "rotary_base": 500000.0, + "use_NTK_by_parts_rope": True, + "NTK_by_parts_low_freq_factor": 1.0, + "NTK_by_parts_high_freq_factor": 4.0, + "NTK_by_parts_factor": 8.0, + } elif "Llama-3.1-8B" in official_model_name: cfg_dict = { "d_model": 4096, @@ -1158,6 +1198,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "normalization_type": "LN", "post_embedding_ln": True, "positional_embedding_type": "alibi", + "default_prepend_bos": False, } elif architecture == "GPT2LMHeadCustomModel": # santacoder @@ -1225,6 +1266,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "trust_remote_code": True, "final_rms": True, "gated_mlp": True, + "default_prepend_bos": False, } elif architecture == "Qwen2ForCausalLM": # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. @@ -1243,12 +1285,13 @@ def convert_hf_model_config(model_name: str, **kwargs): "initializer_range": hf_config.initializer_range, "normalization_type": "RMS", "positional_embedding_type": "rotary", - "rotary_base": hf_config.rope_theta, + "rotary_base": int(hf_config.rope_theta), "rotary_adjacent_pairs": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, "tokenizer_prepends_bos": True, "final_rms": True, "gated_mlp": True, + "default_prepend_bos": False, } elif architecture == "PhiForCausalLM": # Architecture for microsoft/phi models @@ -1309,7 +1352,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "act_fn": "gelu_new", "initializer_range": 0.02, "normalization_type": "RMS", - "rotary_base": 10000.0, + "rotary_base": 10000, "rotary_dim": 256, "positional_embedding_type": "rotary", "use_attn_scale": True, diff --git a/transformer_lens/utilities/attention.py b/transformer_lens/utilities/attention.py index dc38bde99..3deb23a19 100644 --- a/transformer_lens/utilities/attention.py +++ b/transformer_lens/utilities/attention.py @@ -2,6 +2,7 @@ Utilities for attention components. """ + import einops import torch import torch.nn.functional as F @@ -28,11 +29,14 @@ def complex_attn_linear( This is almost the same as simple_attn_linear, but the input tensor has an extra head_index dimension, used when calculating the input of each attention head separately. """ - return ( - einops.einsum( - input, - w, - "batch pos head_index d_model, head_index d_model d_head -> batch pos head_index d_head", - ) - + b + + # Add singleton dimensions for broadcasting + input = einops.rearrange( + input, "batch pos head_index d_model -> batch pos head_index d_model 1" ) + w = einops.rearrange(w, "head_index d_model d_head -> 1 1 head_index d_model d_head") + + # Element-wise multiplication and sum over the d_model dimension + result = input * w + result = result.sum(dim=-2) + return result + b