diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml
index 36ad195ee..481fa220c 100644
--- a/.github/workflows/gh-pages.yml
+++ b/.github/workflows/gh-pages.yml
@@ -20,12 +20,10 @@ jobs:
- uses: actions/checkout@v2
- name: Install Poetry
uses: snok/install-poetry@v1
- with:
- version: 1.4.0
- name: Set up Python
uses: actions/setup-python@v2
with:
- python-version: "3.9"
+ python-version: "3.11"
- name: Install dependencies
run: poetry install --with docs
- name: Build Docs
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 02c4f51d3..3376660f4 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -30,6 +30,7 @@
"gelu",
"githubpages",
"gptj",
+ "howpublished",
"huggingface",
"interpretability",
"isort",
@@ -41,6 +42,7 @@
"Nanda",
"neel",
"neox",
+ "Nitpicky",
"Olah",
"pagename",
"probs",
@@ -51,6 +53,7 @@
"templatedir",
"templatename",
"toctree",
+ "transformerlens",
"Unembed",
"unembedding"
],
diff --git a/docs/make_docs.py b/docs/make_docs.py
index d5a7614bd..da2a98555 100644
--- a/docs/make_docs.py
+++ b/docs/make_docs.py
@@ -1,11 +1,4 @@
-"""
-Generate a markdown table summarizing properties of pretrained models.
-
-This script extracts various properties of pretrained models from the
-`easy_transformer` library, such as the number of parameters, layers, and heads,
-among others, and generates a markdown table. This table is saved to the
-docs directory.
-"""
+"""Build the API Documentation."""
import subprocess
from functools import lru_cache
from pathlib import Path
@@ -76,7 +69,12 @@ def get_property(name, model_name):
def generate_model_table():
- """Generate a markdown table summarizing properties of pretrained models."""
+ """Generate a markdown table summarizing properties of pretrained models.
+
+    This function extracts various properties of pretrained models from the `transformer_lens`
+    library, such as the number of parameters, layers, and heads, among others, and generates a
+    markdown table.
+ """
# Create the table
column_names = [
@@ -115,7 +113,17 @@ def generate_model_table():
def build_docs():
"""Build the docs."""
generate_model_table()
- subprocess.run(["sphinx-build", SOURCE_PATH, BUILD_PATH], check=True)
+
+ subprocess.run(
+ [
+ "sphinx-build",
+ SOURCE_PATH,
+ BUILD_PATH,
+ # "-n", # Nitpicky mode (warn about all missing references)
+ "-W", # Turn warnings into errors
+ ],
+ check=True,
+ )
def docs_hot_reload():
diff --git a/docs/source/conf.py b/docs/source/conf.py
index a2303518b..60f4cd950 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -23,7 +23,6 @@
"sphinx.ext.napoleon",
"myst_parser",
"sphinx.ext.githubpages",
- "sphinx.ext.apidoc",
]
source_suffix = {
diff --git a/docs/source/content/citation.md b/docs/source/content/citation.md
index 9fac69d41..e38b700a6 100644
--- a/docs/source/content/citation.md
+++ b/docs/source/content/citation.md
@@ -1,14 +1,15 @@
-## Citation
+# Citation
Please cite this library as:
-```
-@misc{nandatransformerlens2022,
- title = {TransformerLens},
- author = {Nanda, Neel},
- url = {https://github.com/neelnanda-io/TransformerLens},
- year = {2022}
+
+```BibTeX
+@misc{nanda2022transformerlens,
+ title = {TransformerLens},
+ author = {Neel Nanda},
+ year = {2022},
+ howpublished = {\url{https://github.com/neelnanda-io/TransformerLens}},
}
```
-(This is my best guess for how citing software works, feel free to send a correction!)
+
Also, if you're actually using this for your research, I'd love to chat! Reach out at neelnanda27@gmail.com
diff --git a/docs/source/content/development.md b/docs/source/content/development.md
index ee2bd4c60..086bfe825 100644
--- a/docs/source/content/development.md
+++ b/docs/source/content/development.md
@@ -1,10 +1,10 @@
-## Local Development
+# Local Development
-### DevContainer
+## DevContainer
For a one-click setup of your development environment, this project includes a [DevContainer](https://containers.dev/). It can be used locally with [VS Code](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) or with [GitHub Codespaces](https://github.com/features/codespaces).
-### Manual Setup
+## Manual Setup
This project uses [Poetry](https://python-poetry.org/docs/#installation) for package management. Install as follows (this will also set up your virtual environment):
@@ -17,12 +17,12 @@ Optionally, if you want Jupyter Lab you can run `poetry run pip install jupyterl
Then the library can be imported as `import transformer_lens`.
-### Testing
+## Testing
If adding a feature, please add unit tests for it to the tests folder, and check that it hasn't broken anything major using the existing tests (install pytest and run it in the root TransformerLens/ directory).
To run tests, you can use the following command:
-```
+```shell
poetry run pytest -v transformer_lens/tests
```
diff --git a/docs/source/content/gallery.md b/docs/source/content/gallery.md
index cd4cbdbd0..d3f30262e 100644
--- a/docs/source/content/gallery.md
+++ b/docs/source/content/gallery.md
@@ -1,5 +1,6 @@
-## Gallery
+# Gallery
User contributed examples of the library being used in action:
+
* [Induction Heads Phase Change Replication](https://colab.research.google.com/github/ckkissane/induction-heads-transformer-lens/blob/main/Induction_Heads_Phase_Change.ipynb): A partial replication of [In-Context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) from Connor Kissane
* [Decision Transformer Interpretability](https://github.com/jbloomAus/DecisionTransformerInterpretability): A set of scripts for training decision transformers which uses transformer lens to view intermediate activations, perform attribution and ablations. A write up of the initial work can be found [here](https://www.lesswrong.com/posts/bBuBDJBYHt39Q5zZy/decision-transformer-interpretability).
diff --git a/docs/source/content/getting_started.md b/docs/source/content/getting_started.md
index eae26a073..459b65b44 100644
--- a/docs/source/content/getting_started.md
+++ b/docs/source/content/getting_started.md
@@ -1,22 +1,21 @@
-## Getting Started
+# Getting Started
**Start with the [main demo](https://neelnanda.io/transformer-lens-demo) to learn how the library works, and the basic features**.
-To see what using it for exploratory analysis in practice looks like, check out [my notebook analysing Indirect Objection Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)!
+To see what using it for exploratory analysis in practice looks like, check out [my notebook analysing Indirect Object Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)!
Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **Check out my [list of concrete open problems](https://docs.google.com/document/d/1WONBzNqfKIxERejrrPlQMyKqg7jSFW92x5UMXNrMdPo/edit) to figure out where to start.** It begins with advice on skilling up, and key resources to check out.
If you're new to transformers, check out my [what is a transformer tutorial](https://neelnanda.io/transformer-tutorial) and [tutorial on coding GPT-2 from scratch](https://neelnanda.io/transformer-tutorial-2) (with [an accompanying template](https://neelnanda.io/transformer-template) to write one yourself)!
-### Advice for Reading the Code
+## Advice for Reading the Code
One significant design decision made was to have a single transformer implementation that could support a range of subtly different GPT-style models. This has the upside of interpretability code just working for arbitrary models when you change the model name in `HookedTransformer.from_pretrained`! But it has the significant downside that the code implementing the model (in `HookedTransformer.py` and `components.py`) can be difficult to read. I recommend starting with my [Clean Transformer Demo](https://neelnanda.io/transformer-solution), which is a clean, minimal implementation of GPT-2 with the same internal architecture and activation names as HookedTransformer, but is significantly clearer and better documented.
-### Installation
+## Installation
`pip install git+https://github.com/neelnanda-io/TransformerLens`
Import the library with `import transformer_lens`
(Note: This library used to be known as EasyTransformer, and some breaking changes have been made since the rename. If you need to use the old version with some legacy code, run `pip install git+https://github.com/neelnanda-io/TransformerLens@v1`.)
-
diff --git a/docs/source/content/tutorials.md b/docs/source/content/tutorials.md
index 834676eca..b31cbee32 100644
--- a/docs/source/content/tutorials.md
+++ b/docs/source/content/tutorials.md
@@ -1,14 +1,14 @@
-## Tutorials
+# Tutorials
- **Start with the [main demo](https://neelnanda.io/transformer-lens-demo) to learn how the library works, and the basic features**.
-### Where To Start
+## Where To Start
- To see what using it for exploratory analysis in practice looks like, check out [my notebook analysing Indirect Object Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)!
- [What is a Transformer tutorial](https://neelnanda.io/transformer-tutorial)
-### Demos
+## Demos
- [**Activation Patching in TransformerLens**](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb) - Accompanies the [Exploratory Analysis Demo](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory%20Analysis%20Demo.ipynb). This demo explains how to use [Activation Patching](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx) in TransformerLens, a mechanistic interpretability technique that uses causal intervention to identify which activations in a model matter for producing an output.
diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py
index 8813cab07..532bd107b 100644
--- a/transformer_lens/ActivationCache.py
+++ b/transformer_lens/ActivationCache.py
@@ -268,24 +268,21 @@ def logit_attrs(
difference attributions for the residual stack if incorrect_tokens is provided.
Args:
- residual_stack (Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"]):
- stack of components of residual stream to get logit attributions for.
-
- tokens (Union[str, int, Int[torch.Tensor, ""], Int[torch.Tensor, "batch"],
- Int[torch.Tensor, "batch position"]]): tokens to compute logit attributions on.
- incorrect_tokens (Union[str, int, Int[torch.Tensor, ""], Int[torch.Tensor, "batch"],
- Int[torch.Tensor, "batch position"]], optional): if provided, compute attributions
+ residual_stack: Stack of components of residual stream to get logit attributions for.
+            tokens: Tokens to compute logit attributions on.
+            incorrect_tokens: If provided, compute attributions
on logit difference between tokens and incorrect_tokens. Must have the same shape as
tokens.
- pos_slice (Slice, optional): The slice to apply layer norm scaling on. Defaults to None,
+ pos_slice: The slice to apply layer norm scaling on. Defaults to None,
do nothing.
- batch_slice (Slice, optional): The slice to take on the batch dimension during layer
+ batch_slice: The slice to take on the batch dimension during layer
norm scaling. Defaults to None, do nothing.
- has_batch_dim (bool, optional): Whether residual_stack has a batch dimension. Defaults
+ has_batch_dim: Whether residual_stack has a batch dimension. Defaults
to True.
+
Returns:
- Components: A [num_components, *batch_and_pos_dims] tensor of the logit attributions or
- logit difference attributions if incorrect_tokens was provided.
+ Components: A tensor of the logit attributions or logit difference attributions if
+ incorrect_tokens was provided.
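+
+        Example:
+
+        A minimal sketch (the model name, prompt, and answer token are illustrative, not
+        prescriptive):
+
+        .. code-block:: python
+
+            from transformer_lens import HookedTransformer
+
+            model = HookedTransformer.from_pretrained("gpt2")
+            _, cache = model.run_with_cache("The Eiffel Tower is in the city of")
+
+            # Decompose the residual stream into per-component contributions, then
+            # attribute the answer token's logit to each component.
+            residual_stack, labels = cache.decompose_resid(
+                layer=model.cfg.n_layers, return_labels=True
+            )
+            attributions = cache.logit_attrs(residual_stack, tokens=" Paris", pos_slice=-1)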
"""
if not isinstance(pos_slice, Slice):
pos_slice = Slice(pos_slice)
@@ -352,25 +349,25 @@ def decompose_resid(
useful for attributing model behaviour to different components of the residual stream
Args:
- layer (int): The layer to take components up to - by default includes
+ layer: The layer to take components up to - by default includes
resid_pre for that layer and excludes resid_mid and resid_post for that layer.
layer==n_layers means to return all layer outputs incl in the final layer, layer==0
means just embed and pos_embed. The indices are taken such that this gives the
accumulated streams up to the input to layer l
- incl_mid (bool, optional): Whether to return resid_mid for all previous
+ incl_mid: Whether to return resid_mid for all previous
layers. Defaults to False.
- mlp_input (bool, optional): Whether to include attn_out for the current
+ mlp_input: Whether to include attn_out for the current
layer - essentially decomposing the residual stream that's input to the MLP input
rather than the Attn input. Defaults to False.
- mode (str): Values are "all", "mlp" or "attn". "all" returns all
+ mode: Values are "all", "mlp" or "attn". "all" returns all
components, "mlp" returns only the MLP components, and "attn" returns only the
attention components. Defaults to "all".
- apply_ln (bool, optional): Whether to apply LayerNorm to the stack. Defaults to False.
- pos_slice (Slice): A slice object to apply to the pos dimension.
+ apply_ln: Whether to apply LayerNorm to the stack. Defaults to False.
+ pos_slice: A slice object to apply to the pos dimension.
Defaults to None, do nothing.
- incl_embeds (bool): Whether to include embed & pos_embed return_labels (bool, optional):
- Whether to return a list of labels for
- the residual stream components. Useful for labelling graphs. Defaults to True.
+            incl_embeds: Whether to include embed & pos_embed.
+ return_labels: Whether to return a list of labels for the residual stream components.
+ Useful for labelling graphs. Defaults to True.
Returns:
Components: A [num_components, batch_size, pos, d_model] tensor of the accumulated
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index df1e27d39..357ca4c63 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1338,7 +1338,6 @@ def from_pretrained_no_processing(
def init_weights(self):
"""Initialize weights.
-
Initialize weights matrices with a normal of std=initializer_range (default=0.02). This
roughly follows the GPT-2 paper's scheme (but with truncation, and not halving the std for
W_pos).
@@ -1348,7 +1347,7 @@ def init_weights(self):
Weight matrices are set to empty by default (to save space + compute, since they're the bulk
of the parameters), so it is important to call this if you are not loading in pretrained
- weights! Note that this function assumes that weight names being with W_
+    weights! Note that this function assumes that weight names begin with `W_`.
Set seed here to ensure determinism.
diff --git a/transformer_lens/SVDInterpreter.py b/transformer_lens/SVDInterpreter.py
index df5694ffe..caaecd448 100644
--- a/transformer_lens/SVDInterpreter.py
+++ b/transformer_lens/SVDInterpreter.py
@@ -32,30 +32,46 @@ def get_singular_vectors(
) -> torch.Tensor:
"""Gets the singular vectors for a given vector type, layer, and optionally head.
- Options:
- - OV: Get the singular vectors of the OV matrix for a particular layer and head.
- - w_in: Get the singular vectors of the w_in matrix for a particular layer.
- - w_out: Get the singular vectors of the w_out matrix for a particular layer.
-
- Returns a (d_vocab, 1, num_vectors) tensor.
-
- This tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo for this feature. The demo also points out some "gotchas" in this feature - numerical instability means inconsistency across devices, and the default HookedTransformer parameters don't replicate the original SVD post very well. So I'd recommend checking out the demo if you want to use this!
+        The returned tensor can then be plotted using Neel's PySvelte, as demonstrated in the demo
+        for this feature. The demo also points out some "gotchas" in this feature - numerical
+        instability means inconsistency across devices, and the default HookedTransformer
+        parameters don't replicate the original SVD post very well. So I'd recommend checking out
+        the demo if you want to use this!
Example:
+
.. code-block:: python
- build-docsfrom transformer_lens import HookedTransformer, SVDInterpreter
- build-docsmodel = HookedTransformer.from_pretrained('gpt2-medium')
- build-docssvd_interpreter = SVDInterpreter(model)
- build-docsov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10)
+ from transformer_lens import HookedTransformer, SVDInterpreter
+
+ model = HookedTransformer.from_pretrained('gpt2-medium')
+ svd_interpreter = SVDInterpreter(model)
+
+ ov = svd_interpreter.get_singular_vectors('OV', layer_index=22, head_index=10)
+
+ all_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
+ all_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)]
- build-docsall_tokens = [model.to_str_tokens(np.array([i])) for i in range(model.cfg.d_vocab)]
- build-docsall_tokens = [all_tokens[i][0] for i in range(model.cfg.d_vocab)]
+ def plot_matrix(matrix, tokens, k=10, filter="topk"):
+ pysvelte.TopKTable(
+ tokens=all_tokens,
+ activations=matrix,
+ obj_type="SVD direction",
+ k=k,
+ filter=filter
+ ).show()
- build-docsdef plot_matrix(matrix, tokens, k=10, filter="topk"):
- build-docs pysvelte.TopKTable(tokens=all_tokens, activations=matrix, obj_type="SVD direction", k=k, filter=filter).show()
+ plot_matrix(ov, all_tokens)
- build-docsplot_matrix(ov, all_tokens)"""
+        Args:
+            vector_type: Type of the vector:
+                - "OV": Singular vectors of the OV matrix for a particular layer and head.
+                - "w_in": Singular vectors of the w_in matrix for a particular layer.
+                - "w_out": Singular vectors of the w_out matrix for a particular layer.
+            layer_index: The index of the layer.
+            num_vectors: Number of singular vectors to return.
+            head_index: Index of the head.
+
+        Returns:
+            A (d_vocab, 1, num_vectors) tensor of the singular vectors.
+        """
if head_index is None:
assert vector_type in [
diff --git a/transformer_lens/evals.py b/transformer_lens/evals.py
index 60362e7c1..710560491 100644
--- a/transformer_lens/evals.py
+++ b/transformer_lens/evals.py
@@ -174,7 +174,7 @@ class IOIDataset(Dataset):
Paper: https://arxiv.org/pdf/2211.00593.pdf
Example:
- --------
+
.. code-block:: python
>>> from transformer_lens.evals import ioi_eval, IOIDataset
@@ -281,22 +281,22 @@ def get_default_nouns():
}
-# %%
@torch.inference_mode()
def ioi_eval(
model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False
):
- """
- Evaluates the model on the Indirect Object Identification task.
-
- dataset must be a torch Dataset that returns a dict:
- {
- 'prompt': torch.LongTensor,
- 'IO': torch.LongTensor,
- 'S': torch.LongTensor
- }
-
- Returns average logit difference and accuracy.
+ """Evaluate the Model on the Indirect Object Identification Task.
+
+ Args:
+ model: HookedTransformer model.
+ dataset: PyTorch Dataset that returns a dict with keys "prompt", "IO", and "S".
+ batch_size: Batch size to use.
+ num_samples: Number of samples to use.
+ tokenizer: Tokenizer to use.
+ symmetric: Whether to use the symmetric version of the task.
+
+ Returns:
+ Average logit difference and accuracy.
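+
+    Example:
+
+    A minimal sketch (the model name and sample count are illustrative):
+
+    .. code-block:: python
+
+        >>> from transformer_lens import HookedTransformer
+        >>> from transformer_lens.evals import ioi_eval
+
+        >>> model = HookedTransformer.from_pretrained("gpt2")
+        >>> # With dataset=None, a default IOIDataset is assumed to be built internally.
+        >>> results = ioi_eval(model, num_samples=100)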
"""
if tokenizer is None:
tokenizer = model.tokenizer
diff --git a/transformer_lens/head_detector.py b/transformer_lens/head_detector.py
index 5904a6263..41eb72da9 100644
--- a/transformer_lens/head_detector.py
+++ b/transformer_lens/head_detector.py
@@ -42,73 +42,61 @@ def detect_head(
exclude_current_token: bool = False,
error_measure: ErrorMeasure = "mul",
) -> torch.Tensor:
- """Searches the model (or a set of specific heads, for circuit analysis) for a particular type of attention head.
- This head is specified by a detection pattern, a (sequence_length, sequence_length) tensor representing the attention pattern we expect that type of attention head to show.
- The detection pattern can be also passed not as a tensor, but as a name of one of pre-specified types of attention head (see `HeadName` for available patterns), in which case the tensor is computed within the function itself.
+ """Search for a Particular Type of Attention Head.
- There are two error measures available for quantifying the match between the detection pattern and the actual attention pattern.
+ Searches the model (or a set of specific heads, for circuit analysis) for a particular type of
+ attention head. This head is specified by a detection pattern, a (sequence_length,
+ sequence_length) tensor representing the attention pattern we expect that type of attention head
+ to show. The detection pattern can be also passed not as a tensor, but as a name of one of
+ pre-specified types of attention head (see `HeadName` for available patterns), in which case the
+ tensor is computed within the function itself.
- 1. `"mul"` (default) multiplies both tensors element-wise and divides the sum of the result by the sum of the attention pattern.
- Typically, the detection pattern should in this case contain only ones and zeros, which allows a straightforward interpretation of the score:
- how big fraction of this head's attention is allocated to these specific query-key pairs?
- Using values other than 0 or 1 is not prohibited but will raise a warning (which can be disabled, of course).
- 2. `"abs"` calculates the mean element-wise absolute difference between the detection pattern and the actual attention pattern.
- The "raw result" ranges from 0 to 2 where lower score corresponds to greater accuracy. Subtracting it from 1 maps that range to (-1, 1) interval,
- with 1 being perfect match and -1 perfect mismatch.
+ There are two error measures available for quantifying the match between the detection pattern
+ and the actual attention pattern.
- **Which one should you use?** `"mul"` is likely better for quick or exploratory investigations. For precise examinations where you're trying to
- reproduce as much functionality as possible or really test your understanding of the attention head, you probably want to switch to `"abs"`.
+ 1. `"mul"` (default) multiplies both tensors element-wise and divides the sum of the result by
+ the sum of the attention pattern. Typically, the detection pattern should in this case
+       contain only ones and zeros, which allows a straightforward interpretation of the score:
+       what fraction of this head's attention is allocated to these specific query-key pairs? Using
+ values other than 0 or 1 is not prohibited but will raise a warning (which can be disabled,
+ of course).
- The advantage of `"abs"` is that you can make more precise predictions, and have that measured in the score.
- You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and your score will be better if your prediction is closer.
- The "mul" metric does not allow this, you'll get the same score if attention is 0.2, 0.8 or 0.5, 0.5 or 0.8, 0.2.
+ 2. `"abs"` calculates the mean element-wise absolute difference between the detection pattern
+       and the actual attention pattern. The "raw result" ranges from 0 to 2, where a lower score
+       corresponds to greater accuracy. Subtracting it from 1 maps that range to the (-1, 1)
+       interval, with 1 being a perfect match and -1 a perfect mismatch.
+
+ Which one should you use?
+
+ `"mul"` is likely better for quick or exploratory investigations. For precise examinations where
+ you're trying to reproduce as much functionality as possible or really test your understanding
+ of the attention head, you probably want to switch to `"abs"`.
+
+ The advantage of `"abs"` is that you can make more precise predictions, and have that measured
+ in the score. You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and
+    your score will be better if your prediction is closer. The "mul" metric does not allow this;
+    you'll get the same score if attention is 0.2, 0.8 or 0.5, 0.5 or 0.8, 0.2.
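+
+    As a toy sketch of the arithmetic behind the two measures (the tensors are illustrative, and
+    the library's exact normalisation may differ):
+
+    .. code-block:: python
+
+        import torch
+
+        detection = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
+        actual = torch.tensor([[0.8, 0.2], [0.1, 0.9]])
+
+        mul_score = (detection * actual).sum() / actual.sum()  # 0.85: attention on predicted pairs
+        abs_score = 1 - (detection - actual).abs().mean()  # 0.85: 1 would be a perfect match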
Args:
- ----------
model: Model being used.
seq: String or list of strings being fed to the model.
- head_name: Name of an existing head in HEAD_NAMES we want to check. Must pass either a head_name or a detection_pattern, but not both!
- detection_pattern: (sequence_length, sequence_length) Tensor representing what attention pattern corresponds to the head we're looking for **or** the name of a pre-specified head. Currently available heads are: `["previous_token_head", "duplicate_token_head", "induction_head"]`.
- heads: If specific attention heads is given here, all other heads' score is set to -1. Useful for IOI-style circuit analysis. Heads can be spacified as a list tuples (layer, head) or a dictionary mapping a layer to heads within that layer that we want to analyze.
- cache: Include the cache to save time if you want.
+ head_name: Name of an existing head in HEAD_NAMES we want to check. Must pass either a
+ head_name or a detection_pattern, but not both!
+        detection_pattern: (sequence_length, sequence_length) Tensor representing what attention
+ pattern corresponds to the head we're looking for or the name of a pre-specified head.
+ Currently available heads are: `["previous_token_head", "duplicate_token_head",
+ "induction_head"]`.
+        heads: If specific attention heads are given here, all other heads' scores are set to -1.
+            Useful for IOI-style circuit analysis. Heads can be specified as a list of tuples
+            (layer, head) or a dictionary mapping a layer to heads within that layer that we want
+            to analyze.
+        cache: Include the cache to save time if you want.
exclude_bos: Exclude attention paid to the beginning of sequence token.
exclude_current_token: Exclude attention paid to the current token.
- error_measure: `"mul"` for using element-wise multiplication (default). `"abs"` for using absolute values of element-wise differences as the error measure.
+ error_measure: `"mul"` for using element-wise multiplication. `"abs"` for using absolute
+ values of element-wise differences as the error measure.
Returns:
- ----------
- A (n_layers, n_heads) Tensor representing the score for each attention head.
-
- Example:
- --------
- .. code-block:: python
-
- from transformer_lens import HookedTransformer, utils
- from transformer_lens.head_detector import detect_head
- import plotly.express as px
-
- def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
- px.imshow(
- utils.to_numpy(tensor),
- color_continuous_midpoint=0.0,
- color_continuous_scale="RdBu",
- labels={"x":xaxis, "y":yaxis},
- **kwargs
- ).show(renderer)
-
- model = HookedTransformer.from_pretrained("gpt2-small")
- sequence = "This is a test sequence. This is a test sequence."
-
- attention_score = detect_head(model, sequence, "previous_token_head")
- imshow(
- attention_score,
- zmin=-1, zmax=1,
- xaxis="Head",
- yaxis="Layer",
- title="Previous Head
- Matches"
- )
-
+ Tensor representing the score for each attention head.
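+
+    Example:
+
+    A sketch of visualising the scores (plotly and a notebook environment are assumed):
+
+    .. code-block:: python
+
+        from transformer_lens import HookedTransformer, utils
+        from transformer_lens.head_detector import detect_head
+        import plotly.express as px
+
+        def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
+            px.imshow(
+                utils.to_numpy(tensor),
+                color_continuous_midpoint=0.0,
+                color_continuous_scale="RdBu",
+                labels={"x": xaxis, "y": yaxis},
+                **kwargs,
+            ).show(renderer)
+
+        model = HookedTransformer.from_pretrained("gpt2-small")
+        sequence = "This is a test sequence. This is a test sequence."
+
+        attention_score = detect_head(model, sequence, "previous_token_head")
+        imshow(
+            attention_score,
+            zmin=-1,
+            zmax=1,
+            xaxis="Head",
+            yaxis="Layer",
+            title="Previous Head Matches",
+        )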
"""
cfg = model.cfg
diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py
index f992fcf3c..ce86fd77a 100644
--- a/transformer_lens/hook_points.py
+++ b/transformer_lens/hook_points.py
@@ -13,27 +13,22 @@
@dataclass
class LensHandle:
- """
- A dataclass that holds information about a PyTorch hook.
-
- Attributes:
- hook (hooks.RemovableHandle): Reference to the hook's RemovableHandle.
- is_permanent (bool, optional): Indicates if the hook is permanent. Defaults to False.
- context_level (Optional[int], optional): Context level associated with the hooks context
- manager for the given hook. Defaults to None.
- """
+ """Dataclass that holds information about a PyTorch hook."""
hook: hooks.RemovableHandle
+ """Reference to the Hook's Removable Handle."""
+
is_permanent: bool = False
+ """Indicates if the Hook is Permanent."""
+
context_level: Optional[int] = None
+ """Context level associated with the hooks context manager for the given hook."""
-# %%
# Define type aliases
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str]]]
-# %%
class HookPoint(nn.Module):
"""
A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).
@@ -144,13 +139,23 @@ def layer(self):
# %%
class HookedRootModule(nn.Module):
- """
- A class building on nn.Module to interface nicely with HookPoints
- Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, and run_with_cache to run the model on some input and return a cache of all activations
+ """A class building on nn.Module to interface nicely with HookPoints.
+
+ Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
+ and run_with_cache to run the model on some input and return a cache of all activations.
- WARNING: The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add the fixed version, the broken one is still there. To solve this, run_with_hooks will remove hooks at the end by default, and I recommend using the API of this and run_with_cache. If you want to add hooks into global state, I recommend being intentional about this, and I recommend using reset_hooks liberally in your code to remove any accidentally remaining global state.
+ Notes:
- The main time this goes wrong is when you want to use backward hooks (to cache or intervene on gradients). In this case, you need to keep the hooks around as global state until you've run loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks)
+ The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
+ module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
+ the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
+ hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
+ want to add hooks into global state, I recommend being intentional about this, and I recommend
+ using reset_hooks liberally in your code to remove any accidentally remaining global state.
+
+ The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
+ gradients). In this case, you need to keep the hooks around as global state until you've run
+    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks).
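+
+    Example:
+
+    A minimal sketch of a temporary forward hook (the model, hook point name, and input are
+    illustrative):
+
+    .. code-block:: python
+
+        from transformer_lens import HookedTransformer
+
+        model = HookedTransformer.from_pretrained("gpt2")
+
+        def print_shape(activation, hook):
+            # Hook functions receive the activation tensor and the HookPoint.
+            print(hook.name, activation.shape)
+
+        # Hooks added this way are removed at the end of the call by default.
+        model.run_with_hooks(
+            "Hello world",
+            fwd_hooks=[("blocks.0.attn.hook_pattern", print_shape)],
+        )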
"""
def __init__(self, *args):
@@ -282,7 +287,7 @@ def hooks(
clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.
Example:
- --------
+
.. code-block:: python
with model.hooks(fwd_hooks=my_hooks):
diff --git a/transformer_lens/patching.py b/transformer_lens/patching.py
index e693576d3..b97a95191 100644
--- a/transformer_lens/patching.py
+++ b/transformer_lens/patching.py
@@ -7,7 +7,7 @@
Context:
-Activation Patching is technique introduced in the `ROME paper`_, which
+Activation Patching is a technique introduced in the `ROME paper`, which
uses a causal intervention to identify which activations in a model matter for producing some
output. It runs the model on input A, replaces (patches) an activation with that same activation on
input B, and sees how much that shifts the answer from A to B.
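+
+For example, a minimal sketch (the model, prompts, and metric are illustrative; the two prompts
+must tokenize to the same length):
+
+.. code-block:: python
+
+    from transformer_lens import HookedTransformer
+    import transformer_lens.patching as patching
+
+    model = HookedTransformer.from_pretrained("gpt2")
+    clean_tokens = model.to_tokens("When John and Mary went to the store, John gave a drink to")
+    corrupted_tokens = model.to_tokens("When John and Mary went to the store, Mary gave a drink to")
+
+    _, clean_cache = model.run_with_cache(clean_tokens)
+
+    def metric(logits):
+        # Hypothetical metric: the logit of the correct answer at the final position.
+        return logits[0, -1, model.to_single_token(" Mary")]
+
+    # Patch the residual stream at each (layer, position) from the clean run into the
+    # corrupted run, and record how much the metric recovers.
+    results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric)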