
Commit

Document logit_attrs
alan-cooney committed Oct 21, 2023
1 parent deb979a commit 63f6260
Showing 2 changed files with 40 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/source/content/contributing.md
@@ -113,7 +113,7 @@ You can use LaTeX, but note that as you're placing this in python strings the backslash
must be repeated (i.e. `\\`). You can write LaTeX inline, or in "display mode".

```reStructuredText
- .. math:: (a + b)^2 = a^2 + 2ab + b^2
+ :math:`(a + b)^2 = a^2 + 2ab + b^2`
```

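For illustration only (not part of this diff), here is a minimal sketch of how the doubled backslashes look once the markup sits inside an actual Python docstring; the function name is made up:

```python
def example_function():
    """Example docstring with LaTeX (hypothetical function, for illustration only).

    Inline math uses the :math: role, e.g. :math:`(a + b)^2 = a^2 + 2ab + b^2`.

    Because a docstring is an ordinary Python string, every LaTeX backslash must be
    doubled, e.g. :math:`\\text{Mary} - \\text{John}` rather than a single backslash.
    """
```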
53 changes: 39 additions & 14 deletions transformer_lens/ActivationCache.py
@@ -293,33 +293,39 @@ def accumulated_resid(
]:
"""Accumulated Residual Stream.
- Returns the accumulated residual stream up to a given layer, ie a stack of previous residual
- streams up to that layer's input. This can be thought of as a series of partial values of
- the residual stream, where the model gradually accumulates what it wants.
+ Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
+ Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`
+ style analysis, where it can be thought of as what the model "believes" at each point in the
+ residual stream.
+ If you instead want to look at contributions to the residual stream from each component
+ (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
+ :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
+ MLP neuron.
Args:
layer:
The layer to take components up to - by default includes resid_pre for that layer
- and excludes resid_mid and resid_post for that layer. layer==n_layers, -1 or None
- means to return all residual streams, including the final one (ie immediately pre
- logits). The indices are taken such that this gives the accumulated streams up to
- the input to layer l.
+ and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
+ `None` it will return all residual streams, including the final one (i.e.
+ immediately pre logits). The indices are taken such that this gives the accumulated
+ streams up to the input to layer l.
incl_mid:
- Whether to return resid_mid for all previous layers.
+ Whether to return `resid_mid` for all previous layers.
apply_ln:
Whether to apply LayerNorm to the stack.
pos_slice:
A slice object to apply to the pos dimension. Defaults to None, do nothing.
mlp_input:
- Whether to include resid_mid for the current layer - essentially giving MLP input
- rather than Attn input.
+ Whether to include resid_mid for the current layer. This essentially gives the MLP
+ input rather than the attention input.
return_labels:
Whether to return a list of labels for the residual stream components. Useful for
labelling graphs.
Returns:
- A tensor of the accumulated residual streams. If `return_labels` is True, also returns
- a list of labels for the components (as a tuple in the form `(components, labels)`).
+ A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
+ list of labels for the components (as a tuple in the form `(components, labels)`).
"""
if not isinstance(pos_slice, Slice):
pos_slice = Slice(pos_slice)
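As a usage illustration (not part of this commit), the accumulated residual stream lends itself to the Logit Lens style analysis mentioned in the docstring above. This is a minimal sketch: the model, the prompt, and projecting directly onto `model.W_U` are my own choices, not taken from the diff.

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("The Eiffel Tower is in the city of")
_, cache = model.run_with_cache(tokens)

# Residual stream accumulated up to every layer (plus the final stream), at the last
# position only, with the final LayerNorm applied so it can be unembedded directly.
accumulated_resid, labels = cache.accumulated_resid(
    layer=None, apply_ln=True, pos_slice=-1, return_labels=True
)

# "Logit Lens": project each partial residual stream onto the vocabulary.
logit_lens = accumulated_resid @ model.W_U + model.b_U
for label, logits in zip(labels, logit_lens):
    top_token = model.to_string(logits[0].argmax().item())
    print(f"{label}: {top_token!r}")
```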
@@ -377,8 +383,27 @@ def logit_attrs(
) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
"""Logit Attributions.
- Returns the logit attributions for the residual stack on an input of tokens, or the logit
- difference attributions for the residual stack if incorrect_tokens is provided.
+ Takes a residual stack (typically the residual stream decomposed by components), and
+ calculates how much each item in the stack "contributes" to specific tokens.
+ It does this by:
+ 1. Getting the residual directions of the tokens (i.e. reversing the unembed)
+ 2. Taking the dot product of each item in the residual stack, with the token residual
+ directions.
+ Note that if incorrect tokens are provided, it instead takes the difference between the
+ correct and incorrect tokens (to calculate the residual directions). This is useful as
+ sometimes we want to know e.g. which components are most responsible for selecting the
+ correct token rather than an incorrect one. For example in the `Interpretability in the Wild
+ paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops,
+ John gave a bag to" were investigated, and it was therefore useful to calculate attribution
+ for the :math:`\\text{Mary} - \\text{John}` residual direction.
+ Warning:
+ Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
+ investigating specific components, it's also useful to look at their impact on all tokens
+ (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).
Args:
residual_stack:
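To make the two-step procedure above concrete, here is a hedged usage sketch (not from the commit): it decomposes the residual stream with `decompose_resid` and asks `logit_attrs` for the Mary - John logit-difference attribution of each component. The `residual_stack`, `tokens` and `incorrect_tokens` names follow the docstring above; passing single token ids, the `pos_slice` argument, and leaving the final LayerNorm scaling to `logit_attrs` are assumptions based on how the library's demo notebooks use these cache methods, so treat the exact call signature as a sketch.

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
prompt = "When John and Mary went to the shops, John gave the bag to"
tokens = model.to_tokens(prompt)
_, cache = model.run_with_cache(tokens)

# Per-component contributions to the residual stream at the final position.
# (get_full_resid_decomposition would break this down further, e.g. per neuron.)
residual_stack, labels = cache.decompose_resid(pos_slice=-1, return_labels=True)

# Assumption: logit_attrs scales the stack by the final LayerNorm internally,
# so we pass the raw (un-normed) decomposition here.
# Logit difference attribution for the Mary - John residual direction: how much each
# component pushes towards " Mary" rather than " John" at the final position.
attributions = cache.logit_attrs(
    residual_stack,
    tokens=model.to_single_token(" Mary"),
    incorrect_tokens=model.to_single_token(" John"),
    pos_slice=-1,
)

for label, attribution in zip(labels, attributions[:, 0]):
    print(f"{label}: {attribution.item():+.3f}")
```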
