diff --git a/.github/ISSUE_TEMPLATE/compatibility.md b/.github/ISSUE_TEMPLATE/compatibility.md new file mode 100644 index 000000000..60a4632c9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/compatibility.md @@ -0,0 +1,35 @@ +--- +name: Compatibility Report +about: Submit a compatibility report +title: "[Compatibility Report] Model ID" + +--- + + + +## Model + +REPLACE_WITH_MODEL_ID + +- [ ] This model was incompatible when it was introduced to TransformerLens + + + +The model seems to have worked as of REPLACE_WITH_LAST_COMPATIBLE_VERSION_NUMBER. It first started +showing signs of incompatibility in REPLACE_WITH_FIRST_INCOMPATIBLE_VERSION_NUMBER. + +### Example of some generations in transformers + + +### Code used to load the model in TransformerLens + + +### Example of some generations in TransformerLens diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 1b71d373e..fb686122d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -123,6 +123,7 @@ jobs: # - "Activation_Patching_in_TL_Demo" # - "Attribution_Patching_Demo" - "ARENA_Content" + - "Colab_Compatibility" - "BERT" - "Exploratory_Analysis_Demo" # - "Grokking_Demo" diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb new file mode 100644 index 000000000..62b5814d4 --- /dev/null +++ b/demos/Colab_Compatibility.ipynb @@ -0,0 +1,503 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "\n", + "try:\n", + "    import google.colab\n", + "    IN_COLAB = True\n", + "    print(\"Running as a Colab notebook\")\n", + "except ImportError:\n", + "    IN_COLAB = False\n", + "    print(\"Running as a Jupyter notebook - intended for development only!\")\n", + "    from IPython import get_ipython\n", + "\n", + "    ipython = get_ipython()\n", + "    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel\n", + "    ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + "    ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + "    # %pip install sentencepiece # Llama tokenizer requires sentencepiece\n", + "    %pip install \"transformers>=4.31.0\" # Llama requires transformers>=4.31.0, and transformers in turn requires Python 3.8\n", + "    %pip install torch\n", + "    %pip install tiktoken\n", + "    %pip install transformer_lens\n", + "    %pip install transformers_stream_generator\n", + "    # !huggingface-cli login --token NEEL'S TOKEN" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout",
"output_type": "stream", + "text": [ + "TransformerLens currently supports 190 models out of the box.\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "from transformer_lens import HookedTransformer, HookedEncoderDecoder, loading\n", + "from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer\n", + "from typing import List\n", + "import gc\n", + "\n", + "untested_models = []\n", + "untested_models.extend(loading.OFFICIAL_MODEL_NAMES)\n", + "\n", + "print(\"TransformerLens currently supports \" + str(len(untested_models)) + \" models out of the box.\")\n", + "\n", + "GENERATE = True\n", + "# Fill this in if you have llama weights uploaded, and you with to test those models\n", + "LLAMA_MODEL_PATH = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def mark_models_as_tested(model_set: List[str]) -> None:\n", + " for model in model_set:\n", + " untested_models.remove(model)\n", + " \n", + "\n", + "def run_set(model_set: List[str], device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " tl_model = HookedTransformer.from_pretrained_no_processing(model, device=device)\n", + " if GENERATE:\n", + " print(tl_model.generate(\"Hello my name is\"))\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*\n", + "\n", + "def run_llama_set(model_set: List[str], weight_root: str, device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " # to run this, make sure weight root is the root that contains all models with the \n", + " # sub directories sharing the same name as the model in the list of models\n", + " tokenizer = LlamaTokenizer.from_pretrained(weight_root + model)\n", + " hf_model = LlamaForCausalLM.from_pretrained(weight_root + model, low_cpu_mem_usage=True)\n", + " tl_model = HookedTransformer.from_pretrained_no_processing(\n", + " model, \n", + " hf_model=hf_model,\n", + " device=device,\n", + " fold_ln=False,\n", + " center_writing_weights=False,\n", + " center_unembed=False,\n", + " tokenizer=tokenizer,\n", + " )\n", + " if GENERATE:\n", + " print(tl_model.generate(\"Hello my name is\"))\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*\n", + "\n", + "\n", + "def run_encoder_decoder_set(model_set: List[str], device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " tokenizer = AutoTokenizer.from_pretrained(model)\n", + " tl_model = HookedEncoderDecoder.from_pretrained(model, device=device)\n", + " if GENERATE:\n", + " # Originally from the t5 demo\n", + " prompt = \"Hello, how are you? 
\"\n", + " inputs = tokenizer(prompt, return_tensors=\"pt\")\n", + " input_ids = inputs[\"input_ids\"]\n", + " attention_mask = inputs[\"attention_mask\"]\n", + " decoder_input_ids = torch.tensor([[model.cfg.decoder_start_token_id]]).to(input_ids.device)\n", + "\n", + "\n", + " while True:\n", + " logits = model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n", + " # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n", + "\n", + " token_idx = torch.argmax(logits[0, -1, :]).item()\n", + " print(\"generated token: \\\"\", tokenizer.decode(token_idx), \"\\\", token id: \", token_idx, sep=\"\")\n", + "\n", + " # append token to decoder_input_ids\n", + " decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[token_idx]]).to(input_ids.device)], dim=-1)\n", + "\n", + " # break if End-Of-Sequence token generated\n", + " if token_idx == tokenizer.eos_token_id:\n", + " break\n", + " print(tl_model.generate(\"Hello my name is\"))\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# The following models can run in the T4 free environment\n", + "free_compatible = [\n", + " \"ai-forever/mGPT\",\n", + " \"ArthurConmy/redwood_attn_2l\",\n", + " \"bert-base-cased\",\n", + " \"bigcode/santacoder\",\n", + " \"bigscience/bloom-1b1\",\n", + " \"bigscience/bloom-560m\",\n", + " \"distilgpt2\",\n", + " \"EleutherAI/gpt-neo-1.3B\",\n", + " \"EleutherAI/gpt-neo-125M\",\n", + " \"EleutherAI/gpt-neo-2.7B\",\n", + " \"EleutherAI/pythia-1.4b\",\n", + " \"EleutherAI/pythia-1.4b-deduped\",\n", + " \"EleutherAI/pythia-1.4b-deduped-v0\",\n", + " \"EleutherAI/pythia-1.4b-v0\",\n", + " \"EleutherAI/pythia-14m\",\n", + " \"EleutherAI/pythia-160m\",\n", + " \"EleutherAI/pythia-160m-deduped\",\n", + " \"EleutherAI/pythia-160m-deduped-v0\",\n", + " \"EleutherAI/pythia-160m-seed1\",\n", + " \"EleutherAI/pythia-160m-seed2\",\n", + " \"EleutherAI/pythia-160m-seed3\",\n", + " \"EleutherAI/pythia-160m-v0\",\n", + " \"EleutherAI/pythia-1b\",\n", + " \"EleutherAI/pythia-1b-deduped\",\n", + " \"EleutherAI/pythia-1b-deduped-v0\",\n", + " \"EleutherAI/pythia-1b-v0\",\n", + " \"EleutherAI/pythia-31m\",\n", + " \"EleutherAI/pythia-410m\",\n", + " \"EleutherAI/pythia-410m-deduped\",\n", + " \"EleutherAI/pythia-410m-deduped-v0\",\n", + " \"EleutherAI/pythia-410m-v0\",\n", + " \"EleutherAI/pythia-70m\",\n", + " \"EleutherAI/pythia-70m-deduped\",\n", + " \"EleutherAI/pythia-70m-deduped-v0\",\n", + " \"EleutherAI/pythia-70m-v0\",\n", + " \"facebook/opt-1.3b\",\n", + " \"facebook/opt-125m\",\n", + " \"gpt2\",\n", + " \"gpt2-large\",\n", + " \"gpt2-medium\",\n", + " \"gpt2-xl\",\n", + " \"meta-llama/Llama-3.2-1B\",\n", + " \"meta-llama/Llama-3.2-1B-Instruct\",\n", + " \"microsoft/phi-1\",\n", + " \"microsoft/phi-1_5\",\n", + " \"NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr\",\n", + " \"NeelNanda/Attn_Only_1L512W_C4_Code\",\n", + " \"NeelNanda/Attn_Only_2L512W_C4_Code\",\n", + " \"NeelNanda/Attn_Only_3L512W_C4_Code\",\n", + " \"NeelNanda/Attn_Only_4L512W_C4_Code\",\n", + " \"NeelNanda/GELU_1L512W_C4_Code\",\n", + " \"NeelNanda/GELU_2L512W_C4_Code\",\n", + " \"NeelNanda/GELU_3L512W_C4_Code\",\n", + " \"NeelNanda/GELU_4L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_10L1280W_C4_Code\",\n", + " \"NeelNanda/SoLU_10L_v22_old\",\n", + " \"NeelNanda/SoLU_12L1536W_C4_Code\",\n", + " 
\"NeelNanda/SoLU_12L_v23_old\",\n", + " \"NeelNanda/SoLU_1L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_1L512W_Wiki_Finetune\",\n", + " \"NeelNanda/SoLU_1L_v9_old\",\n", + " \"NeelNanda/SoLU_2L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_2L_v10_old\",\n", + " \"NeelNanda/SoLU_3L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_4L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_4L512W_Wiki_Finetune\",\n", + " \"NeelNanda/SoLU_4L_v11_old\",\n", + " \"NeelNanda/SoLU_6L768W_C4_Code\",\n", + " \"NeelNanda/SoLU_6L_v13_old\",\n", + " \"NeelNanda/SoLU_8L1024W_C4_Code\",\n", + " \"NeelNanda/SoLU_8L_v21_old\",\n", + " \"Qwen/Qwen-1_8B\",\n", + " \"Qwen/Qwen-1_8B-Chat\",\n", + " \"Qwen/Qwen1.5-0.5B\",\n", + " \"Qwen/Qwen1.5-0.5B-Chat\",\n", + " \"Qwen/Qwen1.5-1.8B\",\n", + " \"Qwen/Qwen1.5-1.8B-Chat\",\n", + " \"Qwen/Qwen2-0.5B\",\n", + " \"Qwen/Qwen2-0.5B-Instruct\",\n", + " \"Qwen/Qwen2-1.5B\",\n", + " \"Qwen/Qwen2-1.5B-Instruct\",\n", + " \"roneneldan/TinyStories-1Layer-21M\",\n", + " \"roneneldan/TinyStories-1M\",\n", + " \"roneneldan/TinyStories-28M\",\n", + " \"roneneldan/TinyStories-2Layers-33M\",\n", + " \"roneneldan/TinyStories-33M\",\n", + " \"roneneldan/TinyStories-3M\",\n", + " \"roneneldan/TinyStories-8M\",\n", + " \"roneneldan/TinyStories-Instruct-1M\",\n", + " \"roneneldan/TinyStories-Instruct-28M\",\n", + " \"roneneldan/TinyStories-Instruct-2Layers-33M\",\n", + " \"roneneldan/TinyStories-Instruct-33M\",\n", + " \"roneneldan/TinyStories-Instruct-3M\",\n", + " \"roneneldan/TinyStories-Instruct-8M\",\n", + " \"roneneldan/TinyStories-Instuct-1Layer-21M\",\n", + " \"stanford-crfm/alias-gpt2-small-x21\",\n", + " \"stanford-crfm/arwen-gpt2-medium-x21\",\n", + " \"stanford-crfm/battlestar-gpt2-small-x49\",\n", + " \"stanford-crfm/beren-gpt2-medium-x49\",\n", + " \"stanford-crfm/caprica-gpt2-small-x81\",\n", + " \"stanford-crfm/celebrimbor-gpt2-medium-x81\",\n", + " \"stanford-crfm/darkmatter-gpt2-small-x343\",\n", + " \"stanford-crfm/durin-gpt2-medium-x343\",\n", + " \"stanford-crfm/eowyn-gpt2-medium-x777\",\n", + " \"stanford-crfm/expanse-gpt2-small-x777\",\n", + "]\n", + "\n", + "if IN_COLAB:\n", + " run_set(free_compatible)\n", + " \n", + "mark_models_as_tested(free_compatible)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "paid_gpu_models = [\n", + " \"01-ai/Yi-6B\",\n", + " \"01-ai/Yi-6B-Chat\",\n", + " \"bigscience/bloom-1b7\",\n", + " \"bigscience/bloom-3b\",\n", + " \"bigscience/bloom-7b1\",\n", + " \"codellama/CodeLlama-7b-hf\",\n", + " \"codellama/CodeLlama-7b-Instruct-hf\",\n", + " \"codellama/CodeLlama-7b-Python-hf\",\n", + " \"EleutherAI/pythia-2.8b\",\n", + " \"EleutherAI/pythia-2.8b-deduped\",\n", + " \"EleutherAI/pythia-2.8b-deduped-v0\",\n", + " \"EleutherAI/pythia-2.8b-v0\",\n", + " \"EleutherAI/pythia-6.9b\",\n", + " \"EleutherAI/pythia-6.9b-deduped\",\n", + " \"EleutherAI/pythia-6.9b-deduped-v0\",\n", + " \"EleutherAI/pythia-6.9b-v0\",\n", + " \"facebook/opt-2.7b\",\n", + " \"facebook/opt-6.7b\",\n", + " \"google/gemma-2-2b\",\n", + " \"google/gemma-2-2b-it\",\n", + " \"google/gemma-2b\",\n", + " \"google/gemma-2b-it\",\n", + " \"google/gemma-7b\",\n", + " \"google/gemma-7b-it\",\n", + " \"meta-llama/Llama-2-7b-chat-hf\",\n", + " \"meta-llama/Llama-2-7b-hf\",\n", + " \"meta-llama/Llama-3.1-8B\",\n", + " \"meta-llama/Llama-3.1-8B-Instruct\",\n", + " \"meta-llama/Llama-3.2-3B\",\n", + " \"meta-llama/Llama-3.2-3B-Instruct\",\n", + " \"meta-llama/Meta-Llama-3-8B\",\n", + " \"meta-llama/Meta-Llama-3-8B-Instruct\",\n", + " 
\"microsoft/phi-2\",\n", + " \"microsoft/Phi-3-mini-4k-instruct\",\n", + " \"mistralai/Mistral-7B-Instruct-v0.1\",\n", + " \"mistralai/Mistral-7B-v0.1\",\n", + " \"mistralai/Mistral-Nemo-Base-2407\",\n", + " \"Qwen/Qwen-7B\",\n", + " \"Qwen/Qwen-7B-Chat\",\n", + " \"Qwen/Qwen1.5-4B\",\n", + " \"Qwen/Qwen1.5-4B-Chat\",\n", + " \"Qwen/Qwen1.5-7B\",\n", + " \"Qwen/Qwen1.5-7B-Chat\",\n", + " \"Qwen/Qwen2-7B\",\n", + " \"Qwen/Qwen2-7B-Instruct\",\n", + " \"stabilityai/stablelm-base-alpha-3b\",\n", + " \"stabilityai/stablelm-base-alpha-7b\",\n", + " \"stabilityai/stablelm-tuned-alpha-3b\",\n", + " \"stabilityai/stablelm-tuned-alpha-7b\",\n", + "]\n", + "\n", + "if IN_COLAB:\n", + " run_set(paid_gpu_models)\n", + " \n", + "mark_models_as_tested(paid_gpu_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "paid_cpu_models = [\n", + " \"EleutherAI/gpt-j-6B\",\n", + " \"EleutherAI/gpt-neox-20b\",\n", + " \"EleutherAI/pythia-12b\",\n", + " \"EleutherAI/pythia-12b-deduped\",\n", + " \"EleutherAI/pythia-12b-deduped-v0\",\n", + " \"EleutherAI/pythia-12b-v0\",\n", + " \"facebook/opt-13b\",\n", + " \"google/gemma-2-9b\",\n", + " \"google/gemma-2-9b-it\",\n", + " \"meta-llama/Llama-2-13b-chat-hf\",\n", + " \"meta-llama/Llama-2-13b-hf\",\n", + " \"Qwen/Qwen-14B\",\n", + " \"Qwen/Qwen-14B-Chat\",\n", + " \"Qwen/Qwen1.5-14B\",\n", + " \"Qwen/Qwen1.5-14B-Chat\",\n", + "]\n", + "\n", + "if IN_COLAB:\n", + " run_set(paid_cpu_models, \"cpu\")\n", + " \n", + "mark_models_as_tested(paid_cpu_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "incompatible_models = [\n", + " \"01-ai/Yi-34B\",\n", + " \"01-ai/Yi-34B-Chat\",\n", + " \"facebook/opt-30b\",\n", + " \"facebook/opt-66b\",\n", + " \"google/gemma-2-27b\",\n", + " \"google/gemma-2-27b-it\",\n", + " \"meta-llama/Llama-2-70b-chat-hf\",\n", + " \"meta-llama/Llama-3.1-70B\",\n", + " \"meta-llama/Llama-3.1-70B-Instruct\",\n", + " \"meta-llama/Meta-Llama-3-70B\",\n", + " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n", + " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n", + " \"mistralai/Mixtral-8x7B-v0.1\",\n", + "]\n", + "\n", + "mark_models_as_tested(incompatible_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# The following models take a few extra steps to function. Check the official demo for more\n", + "# information on how to use. 7b and 13b will work in the paid environment. 
30b and 65b will not work\n", + "# in Colab\n", + "not_hosted_models = [\n", + "    \"llama-7b-hf\",\n", + "    \"llama-13b-hf\",\n", + "    \"llama-30b-hf\",\n", + "    \"llama-65b-hf\",\n", + "]\n", + "\n", + "if LLAMA_MODEL_PATH:\n", + "    run_llama_set(not_hosted_models, LLAMA_MODEL_PATH)\n", + "\n", + "mark_models_as_tested(not_hosted_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# These all work on the free version of Colab\n", + "encoder_decoders = [\n", + "    \"google-t5/t5-base\",\n", + "    \"google-t5/t5-large\",\n", + "    \"google-t5/t5-small\",\n", + "]\n", + "if IN_COLAB:\n", + "    run_encoder_decoder_set(encoder_decoders)\n", + "\n", + "mark_models_as_tested(encoder_decoders)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# These models are currently broken in TransformerLens, and are skipped rather than tested.\n", + "# They will show up in the printout below until they are fixed.\n", + "broken_models = [\n", + "    \"Baidicoot/Othello-GPT-Transformer-Lens\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baidicoot/Othello-GPT-Transformer-Lens\n" + ] + } + ], + "source": [ + "# Any models printed below have not been tested. Aside from the known broken models above,\n", + "# this should always remain blank. If your PR fails due to this notebook, most likely you need\n", + "# to check any new model changes to ensure that this notebook is up to date.\n", + "print(*untested_models, sep='\\n')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
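If you are adding a new model to one of the lists in this notebook, the quickest local smoke test mirrors what `run_set` does for a single entry. A minimal sketch, assuming a small model such as `gpt2` (substitute the id you are adding):

```python
# Minimal local smoke test, mirroring the notebook's run_set for one model.
from transformer_lens import HookedTransformer

tl_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")
print(tl_model.generate("Hello my name is", max_new_tokens=10))
```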
diff --git a/docs/source/_static/TransformerLens_Diagram.svg b/docs/source/_static/TransformerLens_Diagram.svg new file mode 100644 index 000000000..fb7a5c65d --- /dev/null +++ b/docs/source/_static/TransformerLens_Diagram.svg @@ -0,0 +1,12396 @@ +[12,396 lines of SVG markup omitted — "Diagram of GPT-2 style LLM in TransformerLens notation": weight matrices (W_Embedding, W_positional, W_Query/W_Key/W_Value/W_Output and their biases, W_in/W_out for the MLP, W_Unembed) in sans serif, and activation tensors (hook_embed, hook_resid_pre/mid/post, ln1/ln2/ln_final hook_normalized, attn.hook_q/k/v/z, attn.hook_attn_scores, attn.hook_pattern, hook_attn_out, mlp.hook_pre, mlp.hook_post, mlp.hook_out, logits, probabilities) in Times New Roman, laid out along the residual stream with tensor dimensions annotated.]
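The activation labels in the diagram are the same strings exposed by TransformerLens's activation cache, so the figure doubles as a key for indexing into `run_with_cache` outputs. A minimal sketch (`gpt2` and layer 0 are arbitrary choices):

```python
# The diagram's activation labels are cache keys; per-layer hooks are
# prefixed with "blocks.<layer>.".
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
logits, cache = model.run_with_cache("Hello my name is")

print(cache["hook_embed"].shape)                  # [batch, pos, d_model]
print(cache["blocks.0.attn.hook_pattern"].shape)  # [batch, n_heads, pos, pos]
print(cache["blocks.0.hook_resid_post"].shape)    # [batch, pos, d_model]
```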
diff --git a/docs/source/index.md b/docs/source/index.md index f1b8737d5..4851b4334 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -18,6 +18,8 @@ I used to work for the [Anthropic interpretability team](https://transformer-cir The core features were heavily inspired by the interface to [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for enabling exploratory research! +A great place to start is this diagram of [all weight matrices and activation tensors with TransformerLens notation](_static/TransformerLens_Diagram.svg), courtesy of [Austin Kozlowski](https://github.com/akozlo). Another helpful resource for getting going as quickly as possible is our [Colab Compatibility Demo](https://github.com/TransformerLensOrg/TransformerLens/tree/main/demos/Colab_Compatibility.ipynb), which will give you a good idea of what you can do in the various Colab environments. + ```{toctree} :hidden: :caption: Introduction diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 3fdd1c1ed..56096484c 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1220,6 +1220,10 @@ def from_pretrained( "right". first_n_layers: If specified, only load the first n layers of the model. """ + if model_name.lower().startswith("t5"): + raise RuntimeError( + "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." + ) assert not ( from_pretrained_kwargs.get("load_in_8bit", False) diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index cc0295323..0b8489976 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -144,9 +144,9 @@ "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", - "CodeLlama-7b-hf", - "CodeLlama-7b-Python-hf", - "CodeLlama-7b-Instruct-hf", + "codellama/CodeLlama-7b-hf", + "codellama/CodeLlama-7b-Python-hf", + "codellama/CodeLlama-7b-Instruct-hf", "meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-70B", @@ -155,6 +155,10 @@ "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.1-70B", + "meta-llama/Llama-3.1-8B", + "meta-llama/Llama-3.1-8B-Instruct", + "meta-llama/Llama-3.1-70B-Instruct", "Baidicoot/Othello-GPT-Transformer-Lens", "bert-base-cased", "roneneldan/TinyStories-1M", @@ -177,6 +181,7 @@ "stabilityai/stablelm-tuned-alpha-7b", "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-Nemo-Base-2407", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1", "bigscience/bloom-560m", @@ -562,12 +567,12 @@ "meta-llama/Llama-2-13b-chat-hf", ], "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], - "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], - "CodeLlama-7b-Python-hf": [ + "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], + "codellama/CodeLlama-7b-Python-hf": [ "CodeLlama-7b-python", "codellama/CodeLlama-7b-Python-hf", ], - "CodeLlama-7b-Instruct-hf": [ + "codellama/CodeLlama-7b-Instruct-hf": [ "CodeLlama-7b-instruct", "codellama/CodeLlama-7b-Instruct-hf", ], @@ -604,6 +609,7 @@ ], "mistralai/Mistral-7B-v0.1": ["mistral-7b"], "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], + "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"], "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], "mistralai/Mixtral-8x7B-Instruct-v0.1": [ "mixtral-instruct",
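The alias entries above feed TransformerLens's name resolution, so either the shorthand or the canonical HuggingFace id should load the same weights. A small sketch of the lookup, assuming the aliases in this hunk (`get_official_model_name` is the resolver in the loading module, imported in the notebook as `loading`):

```python
# Sketch: resolving shorthand aliases to canonical HuggingFace ids.
from transformer_lens import loading

print(loading.get_official_model_name("mistral-nemo-base-2407"))
# expected: mistralai/Mistral-Nemo-Base-2407
print(loading.get_official_model_name("CodeLlama-7b-python"))
# expected: codellama/CodeLlama-7b-Python-hf
```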
"mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], "mistralai/Mixtral-8x7B-Instruct-v0.1": [ "mixtral-instruct", @@ -755,7 +761,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } - elif official_model_name.startswith("CodeLlama-7b"): # same architecture CodeLlama and Llama-2 + elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 cfg_dict = { "d_model": 4096, "d_head": 4096 // 32, @@ -927,13 +933,13 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } - elif "Llama-3.2-1B-Instruct" in official_model_name: + elif "Llama-3.1-8B" in official_model_name: cfg_dict = { - "d_model": 2048, - "d_head": 64, + "d_model": 4096, + "d_head": 128, "n_heads": 32, - "d_mlp": 8192, - "n_layers": 16, + "d_mlp": 14336, + "n_layers": 32, "n_ctx": 2048, # capped due to memory issues "eps": 1e-5, "d_vocab": 128256, @@ -942,17 +948,17 @@ def convert_hf_model_config(model_name: str, **kwargs): "normalization_type": "RMS", "positional_embedding_type": "rotary", "rotary_adjacent_pairs": False, - "rotary_dim": 64, + "rotary_dim": 128, "final_rms": True, "gated_mlp": True, } - elif "Llama-3.2-3B-Instruct" in official_model_name: + elif "Llama-3.1-70B" in official_model_name: cfg_dict = { - "d_model": 3072, + "d_model": 8192, "d_head": 128, - "n_heads": 24, - "d_mlp": 8192, - "n_layers": 28, + "n_heads": 64, + "d_mlp": 28672, + "n_layers": 80, "n_ctx": 2048, # capped due to memory issues "eps": 1e-5, "d_vocab": 128256, @@ -1070,24 +1076,27 @@ def convert_hf_model_config(model_name: str, **kwargs): "attention_dir": "bidirectional", } elif architecture == "MistralForCausalLM": + use_local_attn = True if hf_config.sliding_window else False cfg_dict = { - "d_model": 4096, - "d_head": 4096 // 32, - "n_heads": 32, - "d_mlp": 14336, - "n_layers": 32, + "d_model": hf_config.hidden_size, + "d_head": hf_config.head_dim + if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 + else hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, "n_ctx": 2048, # Capped due to memory issues - "d_vocab": 32000, - "act_fn": "silu", + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "window_size": hf_config.sliding_window, # None if no sliding window was used + "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, + "eps": hf_config.rms_norm_eps, + "rotary_base": hf_config.rope_theta, + "n_key_value_heads": hf_config.num_key_value_heads, + "use_local_attn": use_local_attn, "normalization_type": "RMS", "positional_embedding_type": "rotary", - "window_size": 4096, - "attn_types": ["local"] * 32, - "eps": 1e-05, - "n_key_value_heads": 8, "gated_mlp": True, - "use_local_attn": True, - "rotary_dim": 4096 // 32, } elif architecture == "MixtralForCausalLM": cfg_dict = {