From 63f626085fa4684923405d2ef0d7b59fe314bdc2 Mon Sep 17 00:00:00 2001
From: Alan Cooney <41682961+alan-cooney@users.noreply.github.com>
Date: Sat, 21 Oct 2023 08:26:11 +0800
Subject: [PATCH] Document logit_attrs

---
 docs/source/content/contributing.md |  2 +-
 transformer_lens/ActivationCache.py | 53 +++++++++++++++++++++--------
 2 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md
index 2e33eaa1b..ce66ef258 100644
--- a/docs/source/content/contributing.md
+++ b/docs/source/content/contributing.md
@@ -113,7 +113,7 @@ You can use LaTeX, but note that as you're placing this in python strings the ba
 must be repeated (i.e. `\\`). You can write LaTeX inline, or in "display mode".

 ```reStructuredText
-.. math:: (a + b)^2 = a^2 + 2ab + b^2
+:math:`(a + b)^2 = a^2 + 2ab + b^2`
 ```

 ```reStructuredText
diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py
index 260f22537..0ab84c28b 100644
--- a/transformer_lens/ActivationCache.py
+++ b/transformer_lens/ActivationCache.py
@@ -293,33 +293,39 @@ def accumulated_resid(
     ]:
         """Accumulated Residual Stream.

-        Returns the accumulated residual stream up to a given layer, ie a stack of previous residual
-        streams up to that layer's input. This can be thought of as a series of partial values of
-        the residual stream, where the model gradually accumulates what it wants.
+        Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
+        Lens `
+        style analysis, where it can be thought of as what the model "believes" at each point in the
+        residual stream.
+
+        If you instead want to look at contributions to the residual stream from each component
+        (e.g. for direct logit attribution), see :meth:`decompose_resid`, or
+        :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
+        MLP neuron.

         Args:
             layer:
                 The layer to take components up to - by default includes resid_pre for that layer
-                and excludes resid_mid and resid_post for that layer. layer==n_layers, -1 or None
-                means to return all residual streams, including the final one (ie immediately pre
-                logits). The indices are taken such that this gives the accumulated streams up to
-                the input to layer l.
+                and excludes resid_mid and resid_post for that layer. If set to `n_layers`, `-1` or
+                `None`, it will return all residual streams, including the final one (i.e.
+                immediately pre logits). The indices are taken such that this gives the accumulated
+                streams up to the input to layer l.
             incl_mid:
-                Whether to return resid_mid for all previous layers.
+                Whether to return `resid_mid` for all previous layers.
             apply_ln:
                 Whether to apply LayerNorm to the stack.
             pos_slice:
                 A slice object to apply to the pos dimension. Defaults to None, do nothing.
             mlp_input:
-                Whether to include resid_mid for the current layer - essentially giving MLP input
-                rather than Attn input.
+                Whether to include resid_mid for the current layer. This essentially gives the MLP
+                input rather than the attention input.
             return_labels:
                 Whether to return a list of labels for the residual stream components. Useful for
                 labelling graphs.

         Returns:
-            A tensor of the accumulated residual streams. If `return_labels` is True, also returns
-            a list of labels for the components (as a tuple in the form `(components, labels)`).
+            A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
+            list of labels for the components (as a tuple in the form `(components, labels)`).
         """
         if not isinstance(pos_slice, Slice):
             pos_slice = Slice(pos_slice)
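For orientation, here is a minimal sketch (not part of the patch) of the Logit Lens workflow the docstring above describes. The model choice (`gpt2`) and the prompt are placeholders:

```python
# Hedged sketch: unembed the accumulated residual stream at each layer to see
# what the model "believes" at that point (Logit Lens style analysis).
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # placeholder model
_, cache = model.run_with_cache("The Eiffel Tower is located in the city of")

# Stack of residual streams, with LayerNorm applied so each entry can be
# unembedded directly. `labels` names each component in the stack.
accumulated_resid, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, apply_ln=True, return_labels=True
)

# Project each partial residual stream onto the vocabulary:
# shape [component, batch, pos, d_vocab].
per_layer_logits = accumulated_resid @ model.W_U
```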
@@ -377,8 +383,27 @@ def logit_attrs(
     ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
         """Logit Attributions.

-        Returns the logit attributions for the residual stack on an input of tokens, or the logit
-        difference attributions for the residual stack if incorrect_tokens is provided.
+        Takes a residual stack (typically the residual stream decomposed by components), and
+        calculates how much each item in the stack "contributes" to specific tokens.
+
+        It does this by:
+            1. Getting the residual directions of the tokens (i.e. reversing the unembed)
+            2. Taking the dot product of each item in the residual stack with the token residual
+               directions.
+
+        Note that if incorrect tokens are provided, it instead takes the difference between the
+        correct and incorrect tokens (to calculate the residual directions). This is useful as
+        sometimes we want to know e.g. which components are most responsible for selecting the
+        correct token rather than an incorrect one. For example, in the `Interpretability in the Wild
+        paper ` prompts such as "John and Mary went to the shops,
+        John gave a bag to" were investigated, and it was therefore useful to calculate attribution
+        for the :math:`\\text{Mary} - \\text{John}` residual direction.
+
+        Warning:
+
+            Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
+            investigating specific components, it's also useful to look at their impact on all tokens
+            (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).

         Args:
             residual_stack:
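And a companion sketch (again not part of the patch) of the logit difference attribution use case this docstring describes, on the Interpretability in the Wild style prompt. Pairing it with :meth:`decompose_resid`, and the exact parameter values shown, are illustrative assumptions:

```python
# Hedged sketch: which components push towards " Mary" rather than " John"?
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # placeholder model
prompt = "John and Mary went to the shops, John gave a bag to"
_, cache = model.run_with_cache(prompt)

# Per-component contributions to the final residual stream. apply_ln folds in
# the final LayerNorm, so dot products correspond to direct logit attribution.
residual_stack, labels = cache.decompose_resid(
    layer=-1, apply_ln=True, return_labels=True
)

# Attribution for the Mary - John residual direction, at the final position.
attrs = cache.logit_attrs(
    residual_stack,
    tokens=" Mary",
    incorrect_tokens=" John",
    pos_slice=-1,
)
# attrs[i] estimates how much the component named by labels[i] contributes to
# the " Mary" vs " John" logit difference.
```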