From 649d3be19b0f7283fd81d06e4f94aef8cb6b2cfe Mon Sep 17 00:00:00 2001
From: Neel Nanda <77788841+neelnanda-io@users.noreply.github.com>
Date: Sat, 4 Feb 2023 18:03:34 +0000
Subject: [PATCH] Added Utilities for Activation Patching + A Demo of how to
use them (#165)
* Patching utils
* Adding Activation Patching utils and a demo
---
activation_patching_in_TL_demo.py.ipynb | 1 +
transformer_lens/ActivationCache.py | 29 +-
transformer_lens/__init__.py | 1 +
transformer_lens/patching.py | 380 ++++++++++++++++++++++++
4 files changed, 409 insertions(+), 2 deletions(-)
create mode 100644 activation_patching_in_TL_demo.py.ipynb
create mode 100644 transformer_lens/patching.py
diff --git a/activation_patching_in_TL_demo.py.ipynb b/activation_patching_in_TL_demo.py.ipynb
new file mode 100644
index 000000000..a27b3e3d5
--- /dev/null
+++ b/activation_patching_in_TL_demo.py.ipynb
@@ -0,0 +1 @@
+{"cells":[{"cell_type":"markdown","metadata":{},"source":[" # Activation Patching in TransformerLens Demo\n"," This is an accompaniment to [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo). That notebook explains some basic techniques for mech interp of networks, including an overview of activation patching ([summary here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)). This demonstrates how to use the Activation Patching utils in TransformerLens.\n"]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! 
Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"png\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional\n","from functools import partial\n","import copy\n","\n","import itertools\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["import pysvelte\n","\n","import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training."]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"data":{"text/plain":[""]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["torch.set_grad_enabled(False)"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. 
The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## Activation Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: 
{clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we use the patching module to call activation patching utilities"]},{"cell_type":"code","execution_count":12,"metadata":{},"outputs":[],"source":["# Whether to do the runs by head and by position, which are much slower\n","DO_SLOW_RUNS = False"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching Single Activation Types\n"," We start by patching single types of activation\n"," The general syntax is that the functions are called get_act_patch_... and take in (model, corrupted_tokens, clean_cache, patching_metric)"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch the residual stream at the start of each block over each layer and position\n"," resid_pre -> attn_out, mlp_out, resid_mid all also work"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"3f4b44886e5a47b1b3f3b51ec96e9b85","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(resid_pre_act_patch_results, \n"," yaxis=\"Layer\", \n"," xaxis=\"Position\", \n"," x=[f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," title=\"resid_pre Activation Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch head outputs over each head in each layer, patching across all positions at once\n"," out -> q, k, v, pattern all also work"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a99b6bfcba454ea1b862da644c1f3154","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(attn_head_out_all_pos_act_patch_results, \n"," yaxis=\"Layer\", \n"," xaxis=\"Head\", \n"," title=\"attn_head_out Activation Patching (All Pos)\")"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch head outputs over each head in each layer, patching on each position in turn\n"," out -> q, k, v, pattern all also work, though note that pattern has output shape [layer, pos, head]\n"," We reshape it to plot 
nicely"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[],"source":["ALL_HEAD_LABELS = [f\"L{i}H{j}\" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]\n","if DO_SLOW_RUNS:\n"," attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)\n"," attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, \"layer pos head -> (layer head) pos\")\n"," imshow(attn_head_out_act_patch_results, \n"," yaxis=\"Head Label\", \n"," xaxis=\"Pos\", \n"," x=[f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=ALL_HEAD_LABELS,\n"," title=\"attn_head_out Activation Patching By Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching multiple activation types\n"," Some utilities are provided to patch multiple activations types *in turn*. Note that this is *not* a utility to patch multiple activations at once, it's just a useful scan to get a sense for what's going on in a model\n"," By block: We patch the residual stream at the start of each block, attention output and MLP output over each layer and position"]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"273f73067c394332989ab06fb5f15186","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a3ef4fefd564453ba35fee5a836d6657","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c145d75160f741fd98be311cda941ccc","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. 
We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"e6eefdb1a9984a468e01d5b315d51cd3","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"5234ea6d3a0e40cb83ad760f547befe7","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"93372689514d41a4b94e1e5e52492f98","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"acfbbc29dd5640d1a3c4588daef06611","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"841474694153493e9a74050fde490f84","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)\n","# [markdown]\n","# We can also do by head *and* by position. This is a bit slow, but it can give useful + fine-grained detail"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[],"source":["if DO_SLOW_RUNS:\n"," every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n"," every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n"," imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)"]},{"cell_type":"markdown","metadata":{},"source":[" ## Induction Patching\n"," To show how easy it is, lets do that again with induction heads in a 2L Attention Only model\n"," The input will be repeated random tokens eg BOS 1 5 8 9 2 1 5 8 9 2, and we judge the model's ability to predict the second repetition with its induction heads\n"," Lets call A, B and C different (non-repeated) random sequences. 
We'll start with clean tokens AA and corrupted tokens AB, and see how well the model can predict the second A given the first A"]},{"cell_type":"markdown","metadata":{},"source":[" ### Setup"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Loaded pretrained model attn-only-2l into HookedTransformer\n"]}],"source":["attn_only = HookedTransformer.from_pretrained(\"attn-only-2l\")\n","batch = 4\n","seq_len = 20\n","rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","rand_tokens_B = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","rand_tokens_C = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","bos = torch.tensor([attn_only.tokenizer.bos_token_id]*batch)[:, None].to(attn_only.cfg.device)\n","clean_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_A], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_B], dim=1).to(attn_only.cfg.device)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[],"source":["clean_logits_induction, clean_cache_induction = attn_only.run_with_cache(clean_tokens_induction)\n","corrupted_logits_induction, corrupted_cache_induction = attn_only.run_with_cache(corrupted_tokens_induction)"]},{"cell_type":"markdown","metadata":{},"source":[" We define our metric as negative loss on the second half (negative loss so that higher is better)\n"," This time we won't normalise our metric"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean baseline: -2.2928695678710938\n","Corrupted baseline: -13.125859260559082\n"]}],"source":["def induction_loss(logits, answer_token_indices=rand_tokens_A):\n"," seq_len = answer_token_indices.shape[1]\n","\n"," # logits: batch x seq_len x vocab_size\n"," # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n"," final_logits = logits[:, -seq_len:-1]\n"," final_log_probs = final_logits.log_softmax(-1)\n"," return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\n","CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\n","print(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\n","CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\n","print(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"e622aaef492c45a58526ab56f37e3964","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"382edd8deceb4452b5bf8e157b95178f","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a9065aae9b2a46e0bb595ae3e195b728","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"8bda86976c0944b3a603a796dbd8917d","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, 
?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"50e16757155346a999e1eeb94dcbe49a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","\n","if DO_SLOW_RUNS:\n"," every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n"," every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n"," imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)"]},{"cell_type":"markdown","metadata":{},"source":[" ### Changing the Corrupted Baseline\n"," We can also change the corrupted baseline easily to check what things look like! We'll keep clean as AA, but rather than corrupted as AB, we'll try out:\n"," * BA - This has a corrupted first half, so we expect both keys *and* values to matter. Head output patching should work, but value and key and pattern won't.\n"," * BB - This is still inductiony but with different tokens. So keys, queries and patterns don't matter, head output patching will work, and value will.\n"," * BC - This is just random tokens, so everything is corrupted! 
The induction head needs queries, keys *and* values, so only output will work."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[],"source":["corrupted_tokens_induction_BA = torch.cat([bos, rand_tokens_B, rand_tokens_A], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction_BB = torch.cat([bos, rand_tokens_B, rand_tokens_B], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction_BC = torch.cat([bos, rand_tokens_B, rand_tokens_C], dim=1).to(attn_only.cfg.device)"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"84f0c8fdecfa454591103ae4678260d5","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"36bcd874bc584d279577a14915f39dc3","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c2bb1736cf6a4e43a2263bfe6f1d1309","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"64a9057df33f4d8594101fd7ea74c876","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"39df34a94d184e97add046794dd19cdc","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"7ce8d65176ac47198138463641003156","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"7b0dea8bb0114ac69d08f06f73934b65","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a6af42964f0145f6bcd54e28e9e4adfe","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"3c0989d7873c45cfbcce13fe0289548e","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"f0982b478f43462aaf6501c6fec675ed","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a7d2ec5fe6714deea0d72a26994a24ef","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c131036c69bd4f1e962670df29c4bfaf","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, 
?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"d2bdafd5cbce446492320082a5f7821e","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"c4d1c5cf6c134f68a71c5bee920adac0","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"cf6e5d5937d646c3abf8f9610dc81609","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n"," \n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BA, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BA (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BB, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BB (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BC, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BC (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2}
diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py
index d7b6f6a9c..2a58200a6 100644
--- a/transformer_lens/ActivationCache.py
+++ b/transformer_lens/ActivationCache.py
@@ -264,7 +264,7 @@ def compute_head_results(
def stack_head_results(
self,
- layer: int,
+ layer: int = -1,
return_labels: bool = False,
incl_remainder: bool = False,
pos_slice: Union[Slice, SliceInput] = None,
@@ -274,7 +274,7 @@ def stack_head_results(
Assumes that the model has been run with use_attn_results=True
Args:
- layer (int): Layer index - heads at all layers strictly before this are included. layer must be in [1, n_layers]
+ layer (int): Layer index - heads at all layers strictly before this are included. layer must be in [1, n_layers], or any of -1 / None (both treated as n_layers, ie heads from every layer are included)
return_labels (bool, optional): Whether to also return a list of labels of the form "L0H0" for the heads. Defaults to False.
incl_remainder (bool, optional): Whether to return a final term which is "the rest of the residual stream". Defaults to False.
pos_slice (Slice): A slice object to apply to the pos dimension. Defaults to None, do nothing.
@@ -325,6 +325,31 @@ def stack_head_results(
return components, labels
else:
return components
+
+ def stack_activation(
+ self,
+ activation_name: str,
+ layer: int = -1,
+ sublayer_type: Optional[str] = None,
+ ) -> TT[T.layers_covered, ...]:
+ """Returns a stack of the given activation across layers, ie the named activation at each layer from 0 up to (but excluding) layer, stacked along a new leading layers dimension. The output shape is exactly the same as the underlying activation, just with that extra leading dimension.
+
+ Args:
+ activation_name (str): The name of the activation to be stacked
+ layer (int): Layer index - activations at all layers strictly before this are stacked. layer must be in [1, n_layers], or any of -1 / None (both treated as n_layers, ie the activation from every layer is included)
+ sublayer_type (str, *optional*): The sub layer type of the activation, passed to utils.get_act_name. Can normally be inferred
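+
+ Example (a minimal sketch, assuming `model` is a HookedTransformer and `tokens` a tokenized prompt):
+ logits, cache = model.run_with_cache(tokens)
+ mlp_outs = cache.stack_activation("mlp_out") # [n_layers, batch, pos, d_model]
+ patterns = cache.stack_activation("pattern") # [n_layers, batch, n_heads, dest_pos, src_pos]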
+ """
+
+ if layer is None or layer == -1:
+ # Default to the residual stream immediately pre unembed
+ layer = self.model.cfg.n_layers
+
+ components = []
+ for l in range(layer):
+ components.append(self[(activation_name, l, sublayer_type)])
+
+ return torch.stack(components, dim=0)
def get_neuron_results(
self,
diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py
index b5a0411f4..a5b437ce8 100644
--- a/transformer_lens/__init__.py
+++ b/transformer_lens/__init__.py
@@ -11,6 +11,7 @@
from .ActivationCache import ActivationCache
from .HookedTransformer import HookedTransformer
from . import loading_from_pretrained as loading
+from . import patching
from . import train
from .past_key_value_caching import (
diff --git a/transformer_lens/patching.py b/transformer_lens/patching.py
new file mode 100644
index 000000000..55db190fd
--- /dev/null
+++ b/transformer_lens/patching.py
@@ -0,0 +1,380 @@
+
+# %%
+"""
+A module for patching activations in a transformer model, and measuring the effect of the patch on the output.
+This implements the activation patching technique for a range of types of activation.
+The structure is to have a single generic_activation_patch function that does everything, and to have a range of specialised functions for specific types of activation.
+
+See this explanation for more https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx
+And check out the Activation Patching in TransformerLens Demo notebook for a demo of how to use this module.
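+
+A minimal usage sketch (assuming a HookedTransformer `model`, `corrupted_tokens`, a `clean_cache` from running the model on the clean prompt with run_with_cache, and a `metric` mapping output logits to a scalar):
+
+ import transformer_lens.patching as patching
+ results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric)
+ # results[layer, pos] is the metric after patching resid_pre at that (layer, pos); shape [n_layers, seq_len]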
+"""
+
+from __future__ import annotations
+import torch
+from typing import Optional, Union, Dict, Callable, Sequence, Tuple
+from typing_extensions import Literal
+from torchtyping import TensorType as TT
+
+from transformer_lens.torchtyping_helper import T
+from transformer_lens import HookedTransformer, ActivationCache
+import transformer_lens.utils as utils
+import pandas as pd
+import itertools
+from functools import partial
+from tqdm.auto import tqdm
+
+import einops
+
+# %%
+Logits = torch.Tensor
+AxisNames = Literal["layer", "pos", "head_index", "head", "src_pos", "dest_pos"]
+
+
+# %%
+def make_df_from_ranges(column_max_ranges: Sequence[int], column_names: Sequence[str]) -> pd.DataFrame:
+ """
+ Takes in a list of column names and max ranges for each column, and returns a dataframe with the cartesian product of the ranges (ie iterating through all combinations from zero to column_max_range - 1, in order, with the final column incrementing fastest)
+ """
+ rows = list(itertools.product(*[
+ range(axis_max_range) for axis_max_range in column_max_ranges
+ ]))
+ df = pd.DataFrame(rows, columns=column_names)
+ return df
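+
+# For example, make_df_from_ranges([2, 3], ["layer", "pos"]) gives the rows
+# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2) - the final column increments fastest.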
+
+
+# %%
+CorruptedActivation = torch.Tensor
+PatchedActivation = torch.Tensor
+
+def generic_activation_patch(
+ model: HookedTransformer,
+ corrupted_tokens: TT["batch", "pos"],
+ clean_cache: ActivationCache,
+ patching_metric: Callable[[TT[T.batch, T.pos, T.d_vocab]], TT[()]],
+ patch_setter: Callable[[CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation],
+ activation_name: str,
+ index_axis_names: Optional[Sequence[AxisNames]] = None,
+ index_df: Optional[pd.DataFrame] = None,
+ return_index_df: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
+ """
+ A generic function to do activation patching, will be specialised to specific use cases.
+
+ Activation patching is about studying the counterfactual effect of a specific activation between a clean run and a corrupted run. The idea is to have two inputs, clean and corrupted, which differ in some key detail and produce different outputs, eg "The Eiffel Tower is in" vs "The Colosseum is in". We then run the model on the corrupted input, patch in the corresponding activation cached from the clean run, and measure how much of the clean behaviour is recovered.
+
+ Internally, the function is driven by three things: a list of index tuples (eg (layer, position, head_index)), whose first entry together with activation_name identifies the activation to patch at each step; a patch_setter function which takes the corrupted activation, the index and the corresponding clean activation and applies the patch; and a patching_metric measuring how well the patched model has recovered.
+
+ The indices can either be given explicitly as a pandas dataframe, or by listing the relevant axis names and having them inferred from the tokens and the model config. It is assumed that the first column is always layer.
+
+ This function then iterates over every tuple of indices, applies the corresponding patch, and stores the resulting value of the patching metric
+
+ Params
+ model: The relevant model
+ corrupted_tokens: The input tokens for the corrupted run
+ clean_cache: The cached activations from the clean run
+ patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
+ patch_setter: A function which acts on (corrupted_activation, index, clean_activation) to edit the corrupted activation, patching in the relevant chunk of the clean activation
+ activation_name: The name of the activation being patched
+ index_axis_names: The names of the axes to (fully) iterate over, implicitly fills in index_df
+ index_df: The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df
+ return_index_df: A Boolean flag for whether to return the dataframe of indices too
+
+ Returns
+ patched_output: The tensor of the patching metric for each patch. By default it has one dimension per index axis; if index_df is set explicitly, it is instead flattened with one element per row of index_df.
+ index_df *optional*: The dataframe of indices
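+
+ Example (a sketch, assuming `model`, `corrupted_tokens`, `clean_cache` and `ioi_metric` are set up as in the demo notebook):
+ patched = generic_activation_patch(model, corrupted_tokens, clean_cache, ioi_metric,
+ patch_setter=layer_pos_patch_setter, activation_name="resid_pre", index_axis_names=("layer", "pos"))
+ # patched has shape [n_layers, seq_len] - this is exactly what get_act_patch_resid_pre below does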
+ """
+
+ if index_df is None:
+ assert index_axis_names is not None
+
+ # Get the max range for all possible axes
+ max_axis_range = {
+ "layer": model.cfg.n_layers,
+ "pos": corrupted_tokens.shape[-1],
+ "head_index": model.cfg.n_heads,
+ }
+ max_axis_range["src_pos"] = max_axis_range["pos"]
+ max_axis_range["dest_pos"] = max_axis_range["pos"]
+ max_axis_range["head"] = max_axis_range["head_index"]
+
+ # Get the max range for each axis we iterate over
+ index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]
+
+ # Get the dataframe where each row is a tuple of indices
+ index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)
+
+ flattened_output = False
+ else:
+ # A dataframe of indices was provided. Verify that we did not *also* receive index_axis_names
+ assert index_axis_names is None
+ index_axis_max_range = index_df.max().to_list()
+
+ flattened_output = True
+
+ # Create an empty tensor to show the patched metric for each patch
+ if flattened_output:
+ patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
+ else:
+ patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)
+
+ # A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation
+ def patching_hook(corrupted_activation, hook, index, clean_activation):
+ return patch_setter(corrupted_activation, index, clean_activation)
+
+ # Iterate over every list of indices, and make the appropriate patch!
+ for c, index_row in enumerate(tqdm((list(index_df.iterrows())))):
+ index = index_row[1].to_list()
+
+ # The current activation name is just the activation name plus the layer (assumed to be the first element of the input)
+ current_activation_name = utils.get_act_name(activation_name, layer=index[0])
+
+ # The hook function cannot receive additional inputs, so we use partial to include the specific index and the corresponding clean activation
+ current_hook = partial(
+ patching_hook,
+ index = index,
+ clean_activation = clean_cache[current_activation_name]
+ )
+
+ # Run the model with the patching hook and get the logits!
+ patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)])
+
+ # Calculate the patching metric and store
+ if flattened_output:
+ patched_metric_output[c] = patching_metric(patched_logits).item()
+ else:
+ patched_metric_output[tuple(index)] = patching_metric(patched_logits).item()
+
+ if return_index_df:
+ return patched_metric_output, index_df
+ else:
+ return patched_metric_output
+
+# %%
+# Defining patch setters for various shapes of activations
+def layer_pos_patch_setter(
+ corrupted_activation,
+ index,
+ clean_activation
+ ):
+ """
+ Applies the activation patch where index = [layer, pos]
+
+ Implicitly assumes that the activation axis order is [batch, pos, ...], which is true of everything that is not an attention pattern shaped tensor.
+ """
+ assert len(index)==2
+ layer, pos = index
+ corrupted_activation[:, pos, ...] = clean_activation[:, pos, ...]
+ return corrupted_activation
+
+def layer_pos_head_vector_patch_setter(
+ corrupted_activation,
+ index,
+ clean_activation,
+):
+ """
+ Applies the activation patch where index = [layer, pos, head_index]
+
+ Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
+ """
+ assert len(index)==3
+ layer, pos, head_index = index
+ corrupted_activation[:, pos, head_index] = clean_activation[:, pos, head_index]
+ return corrupted_activation
+
+def layer_head_vector_patch_setter(
+ corrupted_activation,
+ index,
+ clean_activation,
+):
+ """
+ Applies the activation patch where index = [layer, head_index]
+
+ Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
+ """
+ assert len(index)==2
+ layer, head_index = index
+ corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]
+
+ return corrupted_activation
+
+def layer_head_pattern_patch_setter(
+ corrupted_activation,
+ index,
+ clean_activation,
+):
+ """
+ Applies the activation patch where index = [layer, head_index]
+
+ Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
+ """
+ assert len(index)==2
+ layer, head_index = index
+ corrupted_activation[:, head_index, :, :] = clean_activation[:, head_index, :, :]
+
+ return corrupted_activation
+
+def layer_head_pos_pattern_patch_setter(
+ corrupted_activation,
+ index,
+ clean_activation,
+):
+ """
+ Applies the activation patch where index = [layer, head_index, dest_pos]
+
+ Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
+ """
+ assert len(index)==3
+ layer, head_index, dest_pos = index
+ corrupted_activation[:, head_index, dest_pos, :] = clean_activation[:, head_index, dest_pos, :]
+
+ return corrupted_activation
+
+def layer_head_dest_src_pos_pattern_patch_setter(
+ corrupted_activation,
+ index,
+ clean_activation,
+):
+ """
+ Applies the activation patch where index = [layer, head_index, dest_pos, src_pos]
+
+ Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
+ """
+ assert len(index)==4
+ layer, head_index, dest_pos, src_pos = index
+ corrupted_activation[:, head_index, dest_pos, src_pos] = clean_activation[:, head_index, dest_pos, src_pos]
+
+ return corrupted_activation
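+
+# Any function with the signature (corrupted_activation, index, clean_activation) -> patched_activation can
+# be used as a patch_setter. As a sketch of a custom setter (an illustrative assumption, not one of the
+# setters used below), a whole-layer patch indexed by [layer] alone could be written as:
+#
+# def layer_patch_setter(corrupted_activation, index, clean_activation):
+#     assert len(index) == 1
+#     corrupted_activation[:] = clean_activation
+#     return corrupted_activation
+#
+# and paired with index_axis_names=("layer",) in generic_activation_patch.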
+
+# %%
+# Defining activation patching functions for a range of common activation patches.
+get_act_patch_resid_pre = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_patch_setter,
+ activation_name = "resid_pre",
+ index_axis_names = ("layer", "pos")
+)
+get_act_patch_resid_mid = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_patch_setter,
+ activation_name = "resid_mid",
+ index_axis_names = ("layer", "pos")
+)
+get_act_patch_attn_out = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_patch_setter,
+ activation_name = "attn_out",
+ index_axis_names = ("layer", "pos")
+)
+get_act_patch_mlp_out = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_patch_setter,
+ activation_name = "mlp_out",
+ index_axis_names = ("layer", "pos")
+)
+# %%
+get_act_patch_attn_head_out_by_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_head_vector_patch_setter,
+ activation_name = "z",
+ index_axis_names = ("layer", "pos", "head")
+)
+get_act_patch_attn_head_q_by_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_head_vector_patch_setter,
+ activation_name = "q",
+ index_axis_names = ("layer", "pos", "head")
+)
+get_act_patch_attn_head_k_by_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_head_vector_patch_setter,
+ activation_name = "k",
+ index_axis_names = ("layer", "pos", "head")
+)
+get_act_patch_attn_head_v_by_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_pos_head_vector_patch_setter,
+ activation_name = "v",
+ index_axis_names = ("layer", "pos", "head")
+)
+# %%
+get_act_patch_attn_head_pattern_by_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_head_pos_pattern_patch_setter,
+ activation_name = "pattern",
+ index_axis_names = ("layer", "head_index", "dest_pos")
+)
+get_act_patch_attn_head_pattern_dest_src_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_head_dest_src_pos_pattern_patch_setter,
+ activation_name = "pattern",
+ index_axis_names = ("layer", "head_index", "dest_pos", "src_pos")
+)
+
+# %%
+get_act_patch_attn_head_out_all_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_head_vector_patch_setter,
+ activation_name = "z",
+ index_axis_names = ("layer", "head")
+)
+get_act_patch_attn_head_q_all_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_head_vector_patch_setter,
+ activation_name = "q",
+ index_axis_names = ("layer", "head")
+)
+get_act_patch_attn_head_k_all_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_head_vector_patch_setter,
+ activation_name = "k",
+ index_axis_names = ("layer", "head")
+)
+get_act_patch_attn_head_v_all_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_head_vector_patch_setter,
+ activation_name = "v",
+ index_axis_names = ("layer", "head")
+)
+get_act_patch_attn_head_pattern_all_pos = partial(
+ generic_activation_patch,
+ patch_setter = layer_head_pattern_patch_setter,
+ activation_name = "pattern",
+ index_axis_names = ("layer", "head_index")
+)
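+
+# Usage sketch (same assumptions as the module docstring): each all_pos variant returns one metric value per
+# (layer, head) patch, eg
+#     head_out_results = get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric)
+#     # head_out_results has shape [n_layers, n_heads]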
+
+# %%
+
+def get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, metric) -> TT["patch_type":5, "layer", "head"]:
+ """Helper function to get activation patching results for every head (across all positions) for every act type (output, query, key, value, pattern). Wrapper around each act type's patching function; returns a stacked tensor of shape [5, n_layers, n_heads]
+ """
+ act_patch_results = []
+ act_patch_results.append(get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric))
+ return torch.stack(act_patch_results, dim=0)
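+
+# Usage sketch (same assumptions as above):
+#     every_head_result = get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, metric)
+#     # every_head_result[0] is the output patch, [1] query, [2] key, [3] value, [4] pattern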
+
+def get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, metric) -> TT["patch_type":5, "layer", "pos", "head"]:
+ """Helper function to get activation patching results for every head (patching each position separately) for every act type (output, query, key, value, pattern). Wrapper around each act type's patching function; returns a stacked tensor of shape [5, n_layers, pos, n_heads]
+ """
+ act_patch_results = []
+ act_patch_results.append(get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric))
+
+ # Reshape pattern to be compatible with the rest of the results
+ pattern_results = (get_act_patch_attn_head_pattern_by_pos(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(einops.rearrange(pattern_results, "layer head pos -> layer pos head"))
+ return torch.stack(act_patch_results, dim=0)
+
+def get_act_patch_block_every(model, corrupted_tokens, clean_cache, metric) -> TT["patch_type": 3, "layer", "pos"]:
+ """Helper function to get activation patching results for the residual stream (at the start of each block), the output of each Attention layer and the output of each MLP layer. Wrapper around each activation type's patching function; returns a stacked tensor of shape [3, n_layers, pos]
+ """
+ act_patch_results = []
+ act_patch_results.append(get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric))
+ act_patch_results.append(get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric))
+ return torch.stack(act_patch_results, dim=0)
\ No newline at end of file