From 5f739821122a71e094795dd43c55122a352a1847 Mon Sep 17 00:00:00 2001 From: Neel Nanda <77788841+neelnanda-io@users.noreply.github.com> Date: Sat, 4 Feb 2023 22:00:49 +0000 Subject: [PATCH] Removed broken import in demo --- activation_patching_in_TL_demo.py.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/activation_patching_in_TL_demo.py.ipynb b/activation_patching_in_TL_demo.py.ipynb index a27b3e3d5..27226623d 100644 --- a/activation_patching_in_TL_demo.py.ipynb +++ b/activation_patching_in_TL_demo.py.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":[" # Activation Patching in TransformerLens Demo\n"," This is an accompaniment to [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo). That notebook explains some basic techniques for mech interp of networks, including an overview of activation patching ([summary here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)). This demonstrates how to use the Activation Patching utils in TransformerLens.\n"]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"png\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional\n","from functools import partial\n","import copy\n","\n","import itertools\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["import pysvelte\n","\n","import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training."]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"data":{"text/plain":[""]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["torch.set_grad_enabled(False)"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## Activation Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we use the patching module to call activation patching utilities"]},{"cell_type":"code","execution_count":12,"metadata":{},"outputs":[],"source":["# Whether to do the runs by head and by position, which are much slower\n","DO_SLOW_RUNS = False"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching Single Activation Types\n"," We start by patching single types of activation\n"," The general syntax is that the functions are called get_act_patch_... and take in (model, corrupted_tokens, clean_cache, patching_metric)"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch the residual stream at the start of each block over each layer and position\n"," resid_pre -> attn_out, mlp_out, resid_mid all also work"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"3f4b44886e5a47b1b3f3b51ec96e9b85","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(resid_pre_act_patch_results, \n"," yaxis=\"Layer\", \n"," xaxis=\"Position\", \n"," x=[f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," title=\"resid_pre Activation Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch head outputs over each head in each layer, patching across all positions at once\n"," out -> q, k, v, pattern all also work"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a99b6bfcba454ea1b862da644c1f3154","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(attn_head_out_all_pos_act_patch_results, \n"," yaxis=\"Layer\", \n"," xaxis=\"Head\", \n"," title=\"attn_head_out Activation Patching (All Pos)\")"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch head outputs over each head in each layer, patching on each position in turn\n"," out -> q, k, v, pattern all also work, though note that pattern has output shape [layer, pos, head]\n"," We reshape it to plot nicely"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[],"source":["ALL_HEAD_LABELS = [f\"L{i}H{j}\" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]\n","if DO_SLOW_RUNS:\n"," attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)\n"," attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, \"layer pos head -> (layer head) pos\")\n"," imshow(attn_head_out_act_patch_results, \n"," yaxis=\"Head Label\", \n"," xaxis=\"Pos\", \n"," x=[f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=ALL_HEAD_LABELS,\n"," title=\"attn_head_out Activation Patching By Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching multiple activation types\n"," Some utilities are provided to patch multiple activations types *in turn*. Note that this is *not* a utility to patch multiple activations at once, it's just a useful scan to get a sense for what's going on in a model\n"," By block: We patch the residual stream at the start of each block, attention output and MLP output over each layer and position"]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"273f73067c394332989ab06fb5f15186","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"e6eefdb1a9984a468e01d5b315d51cd3","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)\n","# [markdown]\n","# We can also do by head *and* by position. This is a bit slow, but it can give useful + fine-grained detail"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[],"source":["if DO_SLOW_RUNS:\n"," every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n"," every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n"," imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)"]},{"cell_type":"markdown","metadata":{},"source":[" ## Induction Patching\n"," To show how easy it is, lets do that again with induction heads in a 2L Attention Only model\n"," The input will be repeated random tokens eg BOS 1 5 8 9 2 1 5 8 9 2, and we judge the model's ability to predict the second repetition with its induction heads\n"," Lets call A, B and C different (non-repeated) random sequences. We'll start with clean tokens AA and corrupted tokens AB, and see how well the model can predict the second A given the first A"]},{"cell_type":"markdown","metadata":{},"source":[" ### Setup"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Loaded pretrained model attn-only-2l into HookedTransformer\n"]}],"source":["attn_only = HookedTransformer.from_pretrained(\"attn-only-2l\")\n","batch = 4\n","seq_len = 20\n","rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","rand_tokens_B = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","rand_tokens_C = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","bos = torch.tensor([attn_only.tokenizer.bos_token_id]*batch)[:, None].to(attn_only.cfg.device)\n","clean_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_A], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_B], dim=1).to(attn_only.cfg.device)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[],"source":["clean_logits_induction, clean_cache_induction = attn_only.run_with_cache(clean_tokens_induction)\n","corrupted_logits_induction, corrupted_cache_induction = attn_only.run_with_cache(corrupted_tokens_induction)"]},{"cell_type":"markdown","metadata":{},"source":[" We define our metric as negative loss on the second half (negative loss so that higher is better)\n"," This time we won't normalise our metric"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean baseline: -2.2928695678710938\n","Corrupted baseline: -13.125859260559082\n"]}],"source":["def induction_loss(logits, answer_token_indices=rand_tokens_A):\n"," seq_len = answer_token_indices.shape[1]\n","\n"," # logits: batch x seq_len x vocab_size\n"," # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n"," final_logits = logits[:, -seq_len:-1]\n"," final_log_probs = final_logits.log_softmax(-1)\n"," return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\n","CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\n","print(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\n","CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\n","print(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"e622aaef492c45a58526ab56f37e3964","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","\n","if DO_SLOW_RUNS:\n"," every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n"," every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n"," imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)"]},{"cell_type":"markdown","metadata":{},"source":[" ### Changing the Corrupted Baseline\n"," We can also change the corrupted baseline easily to check what things look like! We'll keep clean as AA, but rather than corrupted as AB, we'll try out:\n"," * BA - This has a corrupted first half, so we expect both keys *and* values to matter. Head output patching should work, but value and key and pattern won't.\n"," * BB - This is still inductiony but with different tokens. So keys, queries and patterns don't matter, head output patching will work, and value will.\n"," * BC - This is just random tokens, so everything is corrupted! The induction head needs queries, keys *and* values, so only output will work."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[],"source":["corrupted_tokens_induction_BA = torch.cat([bos, rand_tokens_B, rand_tokens_A], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction_BB = torch.cat([bos, rand_tokens_B, rand_tokens_B], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction_BC = torch.cat([bos, rand_tokens_B, rand_tokens_C], dim=1).to(attn_only.cfg.device)"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"84f0c8fdecfa454591103ae4678260d5","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"7ce8d65176ac47198138463641003156","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a7d2ec5fe6714deea0d72a26994a24ef","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BA, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BA (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BB, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BB (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BC, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BC (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} +{"cells":[{"cell_type":"markdown","metadata":{},"source":[" # Activation Patching in TransformerLens Demo\n"," This is an accompaniment to [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo). That notebook explains some basic techniques for mech interp of networks, including an overview of activation patching ([summary here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)). This demonstrates how to use the Activation Patching utils in TransformerLens.\n"]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"png\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional\n","from functools import partial\n","import copy\n","\n","import itertools\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training."]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"data":{"text/plain":[""]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["torch.set_grad_enabled(False)"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## Activation Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we use the patching module to call activation patching utilities"]},{"cell_type":"code","execution_count":12,"metadata":{},"outputs":[],"source":["# Whether to do the runs by head and by position, which are much slower\n","DO_SLOW_RUNS = False"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching Single Activation Types\n"," We start by patching single types of activation\n"," The general syntax is that the functions are called get_act_patch_... and take in (model, corrupted_tokens, clean_cache, patching_metric)"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch the residual stream at the start of each block over each layer and position\n"," resid_pre -> attn_out, mlp_out, resid_mid all also work"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"3f4b44886e5a47b1b3f3b51ec96e9b85","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(resid_pre_act_patch_results, \n"," yaxis=\"Layer\", \n"," xaxis=\"Position\", \n"," x=[f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," title=\"resid_pre Activation Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch head outputs over each head in each layer, patching across all positions at once\n"," out -> q, k, v, pattern all also work"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a99b6bfcba454ea1b862da644c1f3154","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(attn_head_out_all_pos_act_patch_results, \n"," yaxis=\"Layer\", \n"," xaxis=\"Head\", \n"," title=\"attn_head_out Activation Patching (All Pos)\")"]},{"cell_type":"markdown","metadata":{},"source":[" We can patch head outputs over each head in each layer, patching on each position in turn\n"," out -> q, k, v, pattern all also work, though note that pattern has output shape [layer, pos, head]\n"," We reshape it to plot nicely"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[],"source":["ALL_HEAD_LABELS = [f\"L{i}H{j}\" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]\n","if DO_SLOW_RUNS:\n"," attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)\n"," attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, \"layer pos head -> (layer head) pos\")\n"," imshow(attn_head_out_act_patch_results, \n"," yaxis=\"Head Label\", \n"," xaxis=\"Pos\", \n"," x=[f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=ALL_HEAD_LABELS,\n"," title=\"attn_head_out Activation Patching By Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching multiple activation types\n"," Some utilities are provided to patch multiple activations types *in turn*. Note that this is *not* a utility to patch multiple activations at once, it's just a useful scan to get a sense for what's going on in a model\n"," By block: We patch the residual stream at the start of each block, attention output and MLP output over each layer and position"]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"273f73067c394332989ab06fb5f15186","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"e6eefdb1a9984a468e01d5b315d51cd3","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)\n","# [markdown]\n","# We can also do by head *and* by position. This is a bit slow, but it can give useful + fine-grained detail"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[],"source":["if DO_SLOW_RUNS:\n"," every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n"," every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n"," imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)"]},{"cell_type":"markdown","metadata":{},"source":[" ## Induction Patching\n"," To show how easy it is, lets do that again with induction heads in a 2L Attention Only model\n"," The input will be repeated random tokens eg BOS 1 5 8 9 2 1 5 8 9 2, and we judge the model's ability to predict the second repetition with its induction heads\n"," Lets call A, B and C different (non-repeated) random sequences. We'll start with clean tokens AA and corrupted tokens AB, and see how well the model can predict the second A given the first A"]},{"cell_type":"markdown","metadata":{},"source":[" ### Setup"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Loaded pretrained model attn-only-2l into HookedTransformer\n"]}],"source":["attn_only = HookedTransformer.from_pretrained(\"attn-only-2l\")\n","batch = 4\n","seq_len = 20\n","rand_tokens_A = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","rand_tokens_B = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","rand_tokens_C = torch.randint(100, 10000, (batch, seq_len)).to(attn_only.cfg.device)\n","bos = torch.tensor([attn_only.tokenizer.bos_token_id]*batch)[:, None].to(attn_only.cfg.device)\n","clean_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_A], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_B], dim=1).to(attn_only.cfg.device)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[],"source":["clean_logits_induction, clean_cache_induction = attn_only.run_with_cache(clean_tokens_induction)\n","corrupted_logits_induction, corrupted_cache_induction = attn_only.run_with_cache(corrupted_tokens_induction)"]},{"cell_type":"markdown","metadata":{},"source":[" We define our metric as negative loss on the second half (negative loss so that higher is better)\n"," This time we won't normalise our metric"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean baseline: -2.2928695678710938\n","Corrupted baseline: -13.125859260559082\n"]}],"source":["def induction_loss(logits, answer_token_indices=rand_tokens_A):\n"," seq_len = answer_token_indices.shape[1]\n","\n"," # logits: batch x seq_len x vocab_size\n"," # Take the logits for the answers, cut off the final element to get the predictions for all but the first element of the answers (which can't be predicted)\n"," final_logits = logits[:, -seq_len:-1]\n"," final_log_probs = final_logits.log_softmax(-1)\n"," return final_log_probs.gather(-1, answer_token_indices[:, 1:].unsqueeze(-1)).mean()\n","CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()\n","print(\"Clean baseline:\", CLEAN_BASELINE_INDUCTION)\n","CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()\n","print(\"Corrupted baseline:\", CORRUPTED_BASELINE_INDUCTION)"]},{"cell_type":"markdown","metadata":{},"source":[" ### Patching"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"e622aaef492c45a58526ab56f37e3964","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","\n","if DO_SLOW_RUNS:\n"," every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss)\n"," every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n"," imshow(every_head_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=CLEAN_BASELINE_INDUCTION, x= [f\"{tok}_{i}\" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens[0]))], y=ALL_HEAD_LABELS)"]},{"cell_type":"markdown","metadata":{},"source":[" ### Changing the Corrupted Baseline\n"," We can also change the corrupted baseline easily to check what things look like! We'll keep clean as AA, but rather than corrupted as AB, we'll try out:\n"," * BA - This has a corrupted first half, so we expect both keys *and* values to matter. Head output patching should work, but value and key and pattern won't.\n"," * BB - This is still inductiony but with different tokens. So keys, queries and patterns don't matter, head output patching will work, and value will.\n"," * BC - This is just random tokens, so everything is corrupted! The induction head needs queries, keys *and* values, so only output will work."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[],"source":["corrupted_tokens_induction_BA = torch.cat([bos, rand_tokens_B, rand_tokens_A], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction_BB = torch.cat([bos, rand_tokens_B, rand_tokens_B], dim=1).to(attn_only.cfg.device)\n","corrupted_tokens_induction_BC = torch.cat([bos, rand_tokens_B, rand_tokens_C], dim=1).to(attn_only.cfg.device)"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"84f0c8fdecfa454591103ae4678260d5","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"7ce8d65176ac47198138463641003156","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a7d2ec5fe6714deea0d72a26994a24ef","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/16 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BA, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BA (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BB, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BB (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)\n","every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(attn_only, corrupted_tokens_induction_BC, clean_cache_induction, induction_loss)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head on BC (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=CLEAN_BASELINE_INDUCTION)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2}