From 8c966e6198d8439caa753bedc02d8c221e946dbd Mon Sep 17 00:00:00 2001 From: Alan <41682961+alan-cooney@users.noreply.github.com> Date: Sat, 21 Oct 2023 09:07:26 +0800 Subject: [PATCH] Improve ActivationCache docs (#432) --- .vscode/cspell.json | 5 + docs/source/conf.py | 2 +- docs/source/content/contributing.md | 6 +- makefile | 4 +- pyproject.toml | 6 +- transformer_lens/ActivationCache.py | 495 +++++++++++++++++++--------- transformer_lens/utils.py | 102 ++++-- 7 files changed, 421 insertions(+), 199 deletions(-) diff --git a/.vscode/cspell.json b/.vscode/cspell.json index c0e16d413..082c84373 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -6,8 +6,10 @@ "alonso", "arange", "argmax", + "autodiff", "autoregressive", "barez", + "Beartype", "belrose", "bertsimas", "biderman", @@ -18,6 +20,7 @@ "circuitsvis", "Codespaces", "colab", + "collectstart", "colour", "conmy", "cooney", @@ -35,6 +38,7 @@ "evals", "fazl", "firstpage", + "fspath", "furo", "garriga", "gelu", @@ -49,6 +53,7 @@ "interp", "interpretability", "ioannis", + "ipynb", "isort", "janiak", "jaxtyping", diff --git a/docs/source/conf.py b/docs/source/conf.py index 45a450f07..50e8d4985 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -39,7 +39,7 @@ napoleon_include_init_with_doc = True napoleon_use_admonition_for_notes = True -napoleon_custom_sections = ["Motivation:"] +napoleon_custom_sections = ["Motivation:", "Warning:"] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/docs/source/content/contributing.md b/docs/source/content/contributing.md index 0f5803f36..ce66ef258 100644 --- a/docs/source/content/contributing.md +++ b/docs/source/content/contributing.md @@ -62,6 +62,10 @@ You should follow this order: A description of what the function/class does, including as much detail as is necessary to fully understand it. +Warning: + +Any warnings to the user (e.g. common pitfalls). + Examples: Include any examples here. They will be checked with doctest. @@ -109,7 +113,7 @@ You can use LaTeX, but note that as you're placing this in python strings the ba must be repeated (i.e. `\\`). You can write LaTeX inline, or in "display mode". ```reStructuredText -.. 
math:: (a + b)^2 = a^2 + 2ab + b^2 +:math:`(a + b)^2 = a^2 + 2ab + b^2` ``` ```reStructuredText diff --git a/makefile b/makefile index adc3c05fa..a17630f9b 100644 --- a/makefile +++ b/makefile @@ -15,10 +15,10 @@ acceptance-test: poetry run pytest --cov=transformer_lens/ --cov-report=term-missing --cov-branch tests/acceptance docstring-test: - poetry run pytest transformer_lens/ --doctest-modules --doctest-plus + poetry run pytest transformer_lens/ notebook-test: - poetry run pytest demos/Exploratory_Analysis_Demo.ipynb --nbval + poetry run pytest demos/Exploratory_Analysis_Demo.ipynb test: make unit-test diff --git a/pyproject.toml b/pyproject.toml index e6f82d8ab..16cc9b045 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,11 @@ filterwarnings = [ # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils "ignore:distutils Version classes are deprecated:DeprecationWarning" ] -addopts = "--jaxtyping-packages=transformer_lens,beartype.beartype" +addopts = """--jaxtyping-packages=transformer_lens,beartype.beartype \ +-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning \ +--deselect transformer_lens/utils.py::test_prompt \ +--doctest-modules --doctest-plus \ +--nbval""" [tool.isort] profile = "black" diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 532bd107b..6427dc2e1 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -7,7 +7,8 @@ from __future__ import annotations import logging -from typing import Dict, List, Optional, Tuple, Union +import warnings +from typing import Dict, Iterator, List, Optional, Tuple, Union import einops import numpy as np @@ -24,34 +25,43 @@ class ActivationCache: """Activation Cache. A wrapper around a dictionary of cached activations from a model run, with a variety of helper - functions. In general, any utility which is specifically about editing/processing activations - should be a method here, while any utility which is more general should be a function in - utils.py, and any utility which is specifically about model weights should be in - HookedTransformer.py or components.py. + functions. - NOTE: This is designed to be used with the HookedTransformer class, and will not work with other - models. It's also designed to be used with all activations of HookedTransformer being cached, - and some internal methods will break without that. + This is designed to be used with :class:`transformer_lens.HookedTransformer`, and will not + work with other models. It's also designed to be used with all activations of + :class:`transformer_lens.HookedTransformer` being cached, and some internal methods will break + without that. - WARNING: The biggest footgun and source of bugs in this code will be keeping track of indexes, + The biggest footgun and source of bugs in this code will be keeping track of indexes, dimensions, and the numbers of each. There are several kinds of activations: - Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head] Internal attn - pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape [batch, - head_index, query_pos, key_pos] Attn head results: result. Shape [batch, pos, head_index, - d_model] Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between - activation + layernorm). 
Shape [batch, pos, d_mlp] Residual stream vectors: resid_pre, - resid_mid, resid_post, attn_out, mlp_out, embed, pos_embed, normalized (output of each LN or - LNPre). Shape [batch, pos, d_model] LayerNorm Scale: scale. Shape [batch, pos, 1] - - Sometimes the batch dimension will be missing because we applied remove_batch_dim (used when - batch_size=1), and we need functions to be robust to that. I THINK I've got everything working, - but could easily be wrong! - - Type-Annotations key: layers_covered is the number of layers queried in functions that stack the - residual stream. batch_and_pos_dims is the set of dimensions from batch and pos - by default - this is ["batch", "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] - if we've removed batch dimension and are applying a pos slice which indexes a specific position. + * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. + * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape + [batch, head_index, query_pos, key_pos]. + * Attn head results: result. Shape [batch, pos, head_index, d_model]. + * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + + layernorm). Shape [batch, pos, d_mlp]. + * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, + pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. + * LayerNorm Scale: scale. Shape [batch, pos, 1]. + + Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when + batch_size=1), and as such all library functions *should* be robust to that. + + Type annotations are in the following form: + + * layers_covered is the number of layers queried in functions that stack the residual stream. + * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", + "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed + batch dimension and are applying a pos slice which indexes a specific position. + + Args: + cache_dict: + A dictionary of cached activations from a model run. + model: + The model that the activations are from. + has_batch_dim: + Whether the activations have a batch dimension. """ def __init__( @@ -64,11 +74,17 @@ def __init__( self.has_pos_embed = "hook_pos_embed" in self.cache_dict def remove_batch_dim(self) -> ActivationCache: + """Remove the Batch Dimension (if a single batch item). + + Returns: + The ActivationCache with the batch dimension removed. + """ if self.has_batch_dim: for key in self.cache_dict: assert ( self.cache_dict[key].size(0) == 1 - ), f"Cannot remove batch dimension from cache with batch size > 1, for key {key} with shape {self.cache_dict[key].shape}" + ), f"Cannot remove batch dimension from cache with batch size > 1, \ + for key {key} with shape {self.cache_dict[key].shape}" self.cache_dict[key] = self.cache_dict[key][0] self.has_batch_dim = False else: @@ -77,7 +93,13 @@ def remove_batch_dim(self) -> ActivationCache: ) return self - def __repr__(self): + def __repr__(self) -> str: + """Representation of the ActivationCache. + + Special method that returns a string representation of an object. It's normally used to give + a string that can be used to recreate the object, but here we just return a string that + describes the object. 
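For example, a minimal sketch of how such a cache is typically created and indexed (the model name and prompt are illustrative, and assume the standard `HookedTransformer.run_with_cache` workflow):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # any HookedTransformer model
_logits, cache = model.run_with_cache("Hello world")

print(cache)  # ActivationCache with keys [...]
# Indexing follows the shape conventions listed above
print(cache["resid_pre", 0].shape)  # [batch, pos, d_model]
print(cache["q", 0].shape)          # [batch, pos, head_index, d_head]
print(cache["pattern", 0].shape)    # [batch, head_index, query_pos, key_pos]
```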
+ """ return f"ActivationCache with keys {list(self.cache_dict.keys())}" def __getitem__(self, key) -> torch.Tensor: @@ -88,7 +110,8 @@ def __getitem__(self, key) -> torch.Tensor: dimension order as (get_act_name, layer_index, layer_type). Args: - key: The key or shorthand name for the activation to retrieve. + key: + The key or shorthand name for the activation to retrieve. Returns: The cached activation tensor corresponding to the given key. @@ -104,30 +127,53 @@ def __getitem__(self, key) -> torch.Tensor: key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) return self.cache_dict[utils.get_act_name(*key)] - def __len__(self): + def __len__(self) -> int: + """Length of the ActivationCache. + + Special method that returns the length of an object (in this case the number of different + activations in the cache). + """ return len(self.cache_dict) def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache: """Move the Cache to a Device. - Mostly useful for moving it to CPU after model computation finishes to save GPU memory. - Matmuls will be much slower on the CPU. + Mostly useful for moving the cache to the CPU after model computation finishes to save GPU + memory. Note however that operations will be much slower on the CPU. Note also that some + methods will break unless the model is also moved to the same device, eg + `compute_head_results`. + + Args: + device: + The device to move the cache to (e.g. `torch.device.cpu`). + move_model: + Whether to also move the model to the same device. @deprecated - Note that some methods will break unless the model is also moved to the same device, eg - compute_head_results. """ + # Move model is deprecated as we plan on de-coupling the classes + if move_model is not None: + warnings.warn( + "The 'move_model' parameter is deprecated.", + DeprecationWarning, + ) + self.cache_dict = { key: value.to(device) for key, value in self.cache_dict.items() } if move_model: self.model.to(device) + return self def toggle_autodiff(self, mode: bool = False): - """Set Autodiff to Mode (defaults to turning it off). + """Toggle Autodiff Globally. - WARNING: This is pretty dangerous, since autodiff is global state - this turns off torch's + Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). + + Warning: + + This is pretty dangerous, since autodiff is global state - this turns off torch's ability to take gradients completely and it's easy to get a bunch of errors if you don't realise what you're doing. @@ -138,27 +184,85 @@ def toggle_autodiff(self, mode: bool = False): than its worth. If you don't want to mess with global state, using torch.inference_mode as a context manager - or decorator achieves similar effects :) + or decorator achieves similar effects: + + >>> with torch.inference_mode(): + ... y = torch.Tensor([1., 2, 3]) + >>> y.requires_grad + False """ - logging.warning(f"Changed the global state, set autodiff to {mode}") + logging.warning("Changed the global state, set autodiff to %s", mode) torch.set_grad_enabled(mode) def keys(self): + """Keys of the ActivationCache. + + Examples: + + >>> from transformer_lens import HookedTransformer + >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") + Loaded pretrained model tiny-stories-1M into HookedTransformer + >>> _logits, cache = model.run_with_cache("Some prompt") + >>> list(cache.keys())[0:3] + ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] + + Returns: + List of all keys. 
+ """ return self.cache_dict.keys() def values(self): + """Values of the ActivationCache. + + Returns: + List of all values. + """ return self.cache_dict.values() def items(self): + """Items of the ActivationCache. + + Returns: + List of all items ((key, value) tuples). + """ return self.cache_dict.items() - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[str, torch.Tensor]]: + """ActivationCache Iterator. + + Special method that returns an iterator over the ActivationCache. Allows looping over the + cache. + + Examples: + + >>> from transformer_lens import HookedTransformer + >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") + Loaded pretrained model tiny-stories-1M into HookedTransformer + >>> _logits, cache = model.run_with_cache("Some prompt") + >>> cache_interesting_names = [] + >>> for key in cache: + ... if not key.startswith("blocks.") or key.startswith("blocks.0"): + ... cache_interesting_names.append(key) + >>> print(cache_interesting_names[0:3]) + ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] + + Returns: + Iterator over the cache. + """ return self.cache_dict.__iter__() - def __len__(self): - return len(self.cache_dict) + def apply_slice_to_batch_dim( + self, batch_slice: Union[Slice, SliceInput] + ) -> ActivationCache: + """Apply a Slice to the Batch Dimension. + + Args: + batch_slice: + The slice to apply to the batch dimension. - def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]): + Returns: + The ActivationCache with the batch dimension sliced. + """ if not isinstance(batch_slice, Slice): batch_slice = Slice(batch_slice) assert ( @@ -176,37 +280,57 @@ def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]): def accumulated_resid( self, layer: Optional[int] = None, - incl_mid: bool = False, - apply_ln: bool = False, - pos_slice: Union[Slice, SliceInput] = None, - mlp_input: bool = False, - return_labels: bool = False, - ) -> Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"]: + incl_mid: Optional[bool] = False, + apply_ln: Optional[bool] = False, + pos_slice: Optional[Union[Slice, SliceInput]] = None, + mlp_input: Optional[bool] = False, + return_labels: Optional[bool] = False, + ) -> Union[ + Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], + Tuple[ + Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str] + ], + ]: """Accumulated Residual Stream. - Returns the accumulated residual stream up to a given layer, ie a stack of previous residual - streams up to that layer's input. This can be thought of as a series of partial values of - the residual stream, where the model gradually accumulates what it wants. + Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit + Lens ` + style analysis, where it can be thought of as what the model "believes" at each point in the + residual stream. + + To project this into the vocabulary space, remember that there is a final layer norm in most + decoder-only transformers. Therefore, you need to first apply the final layer norm (which + can be done with :meth:`apply_ln_to_stack`), and then multiply by the unembedding matrix + (:math:`W_U`). + + If you instead want to look at contributions to the residual stream from each component + (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or + :meth:`get_full_resid_decomposition` if you want contributions broken down further into each + MLP neuron. 
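For example, a sketch of the Logit Lens recipe described above (assuming `model` and `cache` come from `run_with_cache`; the layer and position choices are illustrative):

```python
# Accumulate the residual stream after every block, take the final position only,
# apply the final LayerNorm, then project into vocabulary space with W_U.
accumulated_resid, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
scaled_resid = cache.apply_ln_to_stack(accumulated_resid, layer=-1, pos_slice=-1)
logit_lens_logits = scaled_resid @ model.W_U  # [component, batch, d_vocab]
```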
Args: - layer (int, *optional*): The layer to take components up to - by default includes - resid_pre for that layer and excludes resid_mid and resid_post for that layer. - layer==n_layers, -1 or None means to return all residual streams, including the - final one (ie immediately pre logits). The indices are taken such that this gives - the accumulated streams up to the input to layer l - incl_mid (bool, optional): Whether to return resid_mid for all previous layers. Defaults - to False. - mlp_input (bool, optional): Whether to include resid_mid for the current layer - - essentially giving MLP input rather than Attn input. Defaults to False. - apply_ln (bool, optional): Whether to apply LayerNorm to the stack. Defaults to False. - pos_slice (Slice): A slice object to apply to the pos dimension. Defaults to None, do - nothing. - return_labels (bool, optional): Whether to return a list of labels for the residual - stream components. Useful for labelling graphs. Defaults to True. + layer: + The layer to take components up to - by default includes resid_pre for that layer + and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or + `None` it will return all residual streams, including the final one (i.e. + immediately pre logits). The indices are taken such that this gives the accumulated + streams up to the input to layer l. + incl_mid: + Whether to return `resid_mid` for all previous layers. + apply_ln: + Whether to apply LayerNorm to the stack. + pos_slice: + A slice object to apply to the pos dimension. Defaults to None, do nothing. + mlp_input: + Whether to include resid_mid for the current layer. This essentially gives the MLP + input rather than the attention input. + return_labels: + Whether to return a list of labels for the residual stream components. Useful for + labelling graphs. Returns: - Components: A [num_components, batch_size, pos, d_model] tensor of the accumulated - residual streams. (labels): An optional list of labels for the components. + A tensor of the accumulated residual streams. If `return_labels` is True, also returns a + list of labels for the components (as a tuple in the form `(components, labels)`). """ if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) @@ -264,25 +388,47 @@ def logit_attrs( ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: """Logit Attributions. - Returns the logit attributions for the residual stack on an input of tokens, or the logit - difference attributions for the residual stack if incorrect_tokens is provided. + Takes a residual stack (typically the residual stream decomposed by components), and + calculates how much each item in the stack "contributes" to specific tokens. + + It does this by: + 1. Getting the residual directions of the tokens (i.e. reversing the unembed) + 2. Taking the dot product of each item in the residual stack, with the token residual + directions. + + Note that if incorrect tokens are provided, it instead takes the difference between the + correct and incorrect tokens (to calculate the residual directions). This is useful as + sometimes we want to know e.g. which components are most responsible for selecting the + correct token rather than an incorrect one. For example in the `Interpretability in the Wild + paper ` prompts such as "John and Mary went to the shops, + John gave a bag to" were investigated, and it was therefore useful to calculate attribution + for the :math:`\\text{Mary} - \\text{John}` residual direction. 
+ + Warning: + + Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When + investigating specific components it's also useful to look at it's impact on all tokens + (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). Args: - residual_stack: Stack of components of residual stream to get logit attributions for. - tokens: tokens to compute logit attributions on. - incorrect_tokens: if provided, compute attributions - on logit difference between tokens and incorrect_tokens. Must have the same shape as - tokens. - pos_slice: The slice to apply layer norm scaling on. Defaults to None, - do nothing. - batch_slice: The slice to take on the batch dimension during layer - norm scaling. Defaults to None, do nothing. - has_batch_dim: Whether residual_stack has a batch dimension. Defaults - to True. + residual_stack: + Stack of components of residual stream to get logit attributions for. + tokens: + Tokens to compute logit attributions on. + incorrect_tokens: + If provided, compute attributions on logit difference between tokens and + incorrect_tokens. Must have the same shape as tokens. + pos_slice: + The slice to apply layer norm scaling on. Defaults to None, do nothing. + batch_slice: + The slice to take on the batch dimension during layer norm scaling. Defaults to + None, do nothing. + has_batch_dim: + Whether residual_stack has a batch dimension. Defaults to True. Returns: - Components: A tensor of the logit attributions or logit difference attributions if - incorrect_tokens was provided. + A tensor of the logit attributions or logit difference attributions if incorrect_tokens + was provided. """ if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) @@ -309,7 +455,9 @@ def logit_attrs( if tokens.shape != incorrect_tokens.shape: raise ValueError( - f"tokens and incorrect_tokens must have the same shape! (tokens.shape={tokens.shape}, incorrect_tokens.shape={incorrect_tokens.shape})" + f"tokens and incorrect_tokens must have the same shape! \ + (tokens.shape={tokens.shape}, \ + incorrect_tokens.shape={incorrect_tokens.shape})" ) # If incorrect_tokens was provided, take the logit difference @@ -341,7 +489,12 @@ def decompose_resid( pos_slice: Union[Slice, SliceInput] = None, incl_embeds: bool = True, return_labels: bool = False, - ) -> Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"]: + ) -> Union[ + Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], + Tuple[ + Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str] + ], + ]: """Decompose the Residual Stream. Decomposes the residual stream input to layer L into a stack of the output of previous @@ -349,29 +502,37 @@ def decompose_resid( useful for attributing model behaviour to different components of the residual stream Args: - layer: The layer to take components up to - by default includes + layer: + The layer to take components up to - by default includes resid_pre for that layer and excludes resid_mid and resid_post for that layer. layer==n_layers means to return all layer outputs incl in the final layer, layer==0 means just embed and pos_embed. The indices are taken such that this gives the accumulated streams up to the input to layer l - incl_mid: Whether to return resid_mid for all previous - layers. Defaults to False. - mlp_input: Whether to include attn_out for the current + incl_mid: + Whether to return resid_mid for all previous + layers. 
+ mlp_input: + Whether to include attn_out for the current layer - essentially decomposing the residual stream that's input to the MLP input - rather than the Attn input. Defaults to False. - mode: Values are "all", "mlp" or "attn". "all" returns all + rather than the Attn input. + mode: + Values are "all", "mlp" or "attn". "all" returns all components, "mlp" returns only the MLP components, and "attn" returns only the attention components. Defaults to "all". - apply_ln: Whether to apply LayerNorm to the stack. Defaults to False. - pos_slice: A slice object to apply to the pos dimension. + apply_ln: + Whether to apply LayerNorm to the stack. + pos_slice: + A slice object to apply to the pos dimension. Defaults to None, do nothing. - incl_embeds: Whether to include embed & pos_embed - return_labels: Whether to return a list of labels for the residual stream components. - Useful for labelling graphs. Defaults to True. + incl_embeds: + Whether to include embed & pos_embed + return_labels: + Whether to return a list of labels for the residual stream components. + Useful for labelling graphs. Returns: - Components: A [num_components, batch_size, pos, d_model] tensor of the accumulated - residual streams. (labels): An optional list of labels for the components. + A tensor of the accumulated residual streams. If `return_labels` is True, also returns + a list of labels for the components (as a tuple in the form `(components, labels)`). """ if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) @@ -453,16 +614,17 @@ def stack_head_results( notation). Args: - layer (int): Layer index - heads at all layers strictly before this are included. layer - must be in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final - layer - return_labels (bool, optional): Whether to also return a list of labels of the form - "L0H0" for the heads. Defaults to False. - incl_remainder (bool, optional): Whether to return a final term which is "the rest of - the residual stream". Defaults to False. - pos_slice (Slice): A slice object to apply to the pos dimension. Defaults to None, do - nothing. - apply_ln (bool, optional): Whether to apply LayerNorm to the stack. Defaults to False. + layer: + Layer index - heads at all layers strictly before this are included. layer must be + in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. + return_labels: + Whether to also return a list of labels of the form "L0H0" for the heads. + incl_remainder: + Whether to return a final term which is "the rest of the residual stream". + pos_slice: + A slice object to apply to the pos dimension. Defaults to None, do nothing. + apply_ln: + Whether to apply LayerNorm to the stack. """ if not isinstance(pos_slice, Slice): pos_slice = Slice(pos_slice) @@ -524,21 +686,20 @@ def stack_activation( ) -> Float[torch.Tensor, "layers_covered ..."]: """Stack Activations. - Returns a stack of all head results (ie residual stream contribution) up to layer L. A good - way to decompose the outputs of attention layers into attribution by specific heads. The - output shape is exactly the same shape as the activations, just with a leading layers - dimension. + Flexible way to stack activations with a given name. Args: - activation_name (str): The name of the activation to be stacked layer (int): 'Layer - index - heads' at all layers strictly before this are included. layer must be in [1, - n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 
- sublayer_type (str, *optional*): The sub layer type of the activation, passed to - utils.get_act_name. Can normally be inferred. - incl_remainder (bool, optional): Whether to return a final term which is "the rest of - the residual stream". Defaults to False. + activation_name: + The name of the activation to be stacked + layer: + 'Layer index - heads' at all layers strictly before this are included. layer must be + in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. + sublayer_type: + The sub layer type of the activation, passed to utils.get_act_name. Can normally be + inferred. + incl_remainder: + Whether to return a final term which is "the rest of the residual stream". """ - if layer is None or layer == -1: # Default to the residual stream immediately pre unembed layer = self.model.cfg.n_layers @@ -562,14 +723,15 @@ def get_neuron_results( to all of them. Does *not* cache these because it's expensive in space and cheap to compute. Args: - layer (int): Layer index neuron_slice (Slice, optional): Slice of the neuron. Defaults - to None. pos_slice (Slice, optional): Slice of the positions. Defaults to None. See - `utils.Slice` for details. - + layer: + Layer index. + neuron_slice: + Slice of the neuron. + pos_slice: + Slice of the positions. Returns: - Tensor: [batch_size, pos, d_mlp, d_model] tensor of the results (d_mlp is the neuron - index axis) + Tensor of the results. """ if type(neuron_slice) is not Slice: assert isinstance(neuron_slice, SliceInput) @@ -614,17 +776,19 @@ def stack_neuron_results( small models or short inputs. Args: - layer (int): Layer index - heads at all layers strictly before this are included. layer - must be in [1, n_layers] - pos_slice (Slice, optional): Slice of the positions. Defaults to None. See utils.Slice - for details. - neuron_slice (Slice, optional): Slice of the neurons. Defaults to None. See utils.Slice - for details. - return_labels (bool, optional): Whether to also return a list of labels of the form - "L0H0" for the heads. Defaults to False. - incl_remainder (bool, optional): Whether to return a final term which is "the rest of - the residual stream". Defaults to False. - apply_ln (bool, optional): Whether to apply LayerNorm to the stack. Defaults to False. + layer: + Layer index - heads at all layers strictly before this are included. layer must be + in [1, n_layers] + pos_slice: + Slice of the positions. + neuron_slice: + Slice of the neurons. + return_labels: + Whether to also return a list of labels of the form "L0H0" for the heads. + incl_remainder: + Whether to return a final term which is "the rest of the residual stream". + apply_ln: + Whether to apply LayerNorm to the stack. """ if layer is None or layer == -1: @@ -704,22 +868,26 @@ def apply_ln_to_stack( If the model does not use LayerNorm, it returns the residual stack unchanged. Args: - residual_stack (torch.Tensor): A tensor, whose final dimension is - d_model. The other trailing dimensions are assumed to be the same as the stored - hook_scale - which may or may not include batch or position dimensions. - layer (int): The layer we're taking the input to. In [0, n_layers], - n_layers means the unembed. None maps to the n_layers case, ie the unembed. - mlp_input (bool, optional): Whether the input is to the MLP or attn - (ie ln2 vs ln1). Defaults to False, ie ln1. 
If layer==n_layers, must be False, and - we use ln_final - pos_slice (Slice, optional): The slice to take of positions, if residual_stack is not - over the full context, None means do nothing. It is assumed that pos_slice has - already been applied to residual_stack, and this is only applied to the scale. See - utils.Slice for details. Defaults to None, do nothing. - batch_slice (Slice, optional): The slice to take on the batch dimension. + residual_stack: + A tensor, whose final dimension is d_model. The other trailing dimensions are + assumed to be the same as the stored hook_scale - which may or may not include batch + or position dimensions. + layer: + The layer we're taking the input to. In [0, n_layers], n_layers means the unembed. + None maps to the n_layers case, ie the unembed. + mlp_input: + Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1. + If layer==n_layers, must be False, and we use ln_final + pos_slice: + The slice to take of positions, if residual_stack is not over the full context, None + means do nothing. It is assumed that pos_slice has already been applied to + residual_stack, and this is only applied to the scale. See utils.Slice for details. Defaults to None, do nothing. - has_batch_dim (bool, optional): Whether residual_stack has a batch dimension. - Defaults to True. + batch_slice: + The slice to take on the batch dimension. Defaults to None, do nothing. + has_batch_dim: + Whether residual_stack has a batch dimension. + """ if self.model.cfg.normalization_type not in ["LN", "LNPre"]: # The model does not use LayerNorm, so we don't need to do anything. @@ -772,17 +940,22 @@ def get_full_resid_decomposition( that is input into some layer. Args: - layer (int): The layer we're inputting into. layer is in [0, n_layers], if - layer==n_layers (or None) we're inputting into the unembed (the entire stream), if - layer==0 then it's just embed and pos_embed - mlp_input (bool, optional): Are we inputting to the MLP in that layer or the attn? Must - be False for final layer, since that's the unembed. Defaults to False. - expand_neurons (bool, optional): Whether to expand the MLP outputs to give every - neuron's result or just return the MLP layer outputs. Defaults to True. - apply_ln (bool, optional): Whether to apply LayerNorm to the stack. Defaults to False. - pos_slice (Slice, optional): Slice of the positions to take. Defaults to None. See - utils.Slice for details. - return_labels (bool): Whether to return the labels. Defaults to False. + layer: + The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or + None) we're inputting into the unembed (the entire stream), if layer==0 then it's + just embed and pos_embed + mlp_input: + Are we inputting to the MLP in that layer or the attn? Must be False for final + layer, since that's the unembed. + expand_neurons: + Whether to expand the MLP outputs to give every neuron's result or just return the + MLP layer outputs. + apply_ln: + Whether to apply LayerNorm to the stack. + pos_slice: + Slice of the positions to take. + return_labels: + Whether to return the labels. 
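For example, a brief sketch (the argument choices and label formats are illustrative):

```python
# Per-component contributions (embeddings, heads, MLP layers, bias) to the residual
# stream just before the unembed, with matching labels.
full_decomp, labels = cache.get_full_resid_decomposition(
    expand_neurons=False, apply_ln=True, pos_slice=-1, return_labels=True
)
by_component = dict(zip(labels, full_decomp))  # e.g. by_component["L0H0"], by_component["embed"]
```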
""" if layer is None or layer == -1: # Default to the residual stream immediately pre unembed diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index f0f77febf..5306e502f 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -356,7 +356,6 @@ def sample_logits( return torch.distributions.categorical.Categorical(logits=final_logits).sample() -# %% # Type alias SliceInput: Type = Optional[ Union[ @@ -369,7 +368,8 @@ def sample_logits( np.ndarray, ] ] -""" +"""An object that represents a slice input. It can be a tuple of integers or a slice object. + An optional type alias for a slice input used in the `ActivationCache` module. A `SliceInput` can be one of the following types: @@ -380,14 +380,12 @@ def sample_logits( - `torch.Tensor`: a tensor containing a boolean mask or a list of indices to be selected from the input tensor. `SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module. - -:class:`SliceInput` - An object that represents a slice input. It can be a tuple of integers or a slice object. """ class Slice: - """ + """An object that represents a slice input. It can be a tuple of integers or a slice object. + We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions: Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that) @@ -404,9 +402,6 @@ class Slice: elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3]) elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out). elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices. - - :class: `Slice` - An object that represents a slice input. It can be a tuple of integers or a slice object. """ def __init__( @@ -419,9 +414,6 @@ def __init__( Args: input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a list, a torch.Tensor, or None. If None, do nothing. - Returns: - Slice: A Slice object that can be applied to a tensor. - Raises: ValueError: If the input_slice is not one of the above types. """ @@ -492,9 +484,6 @@ def __repr__( return f"Slice: {self.slice} Mode: {self.mode} " -# %% - - def get_act_name( name: str, layer: Optional[Union[int, str]] = None, @@ -610,27 +599,76 @@ def remove_batch_dim( return tensor +# Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit +# test (because it's name is prefixed `test_`). def test_prompt( prompt: str, answer: str, - model, - prepend_space_to_answer: bool = True, - print_details: bool = True, - prepend_bos: Union[bool, None] = USE_DEFAULT_VALUE, - top_k: int = 10, -): - """ - Function to test whether a model can give the correct answer to a prompt. Intended for exploratory analysis, so it prints things out rather than returning things. - - Works for multi-token answers and multi-token prompts. - - Will always print the ranks of the answer tokens, and if print_details will print the logit and prob for the answer tokens and the top k tokens returned for each answer position. + model, # Can't give type hint due to circular imports + prepend_space_to_answer: Optional[bool] = True, + print_details: Optional[bool] = True, + prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, + top_k: Optional[int] = 10, +) -> None: + """Test if the Model Can Give the Correct Answer to a Prompt. + + Intended for exploratory analysis. 
Prints out the performance on the answer (rank, logit, prob), + as well as the top k tokens. Works for multi-token prompts and multi-token answers. + + Warning: + + This will print the results (it does not return them). + + Examples: + + >>> from transformer_lens import HookedTransformer, utils + >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") + Loaded pretrained model tiny-stories-1M into HookedTransformer + + >>> prompt = "Why did the elephant cross the" + >>> answer = "road" + >>> utils.test_prompt(prompt, answer, model) + Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] + Tokenized answer: [' road'] + Performance on answer token: + Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road| + Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground| + Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree| + Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road| + Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car| + Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river| + Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street| + Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k| + Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill| + Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing| + Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park| + Ranks of the answer tokens: [(' road', 2)] Args: - prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend - the BOS token to the input (applicable when input is a string). Defaults to None, - implying usage of self.cfg.default_prepend_bos (default is True unless specified - otherwise). Pass True or False to override the default. + prompt: + The prompt string, e.g. "Why did the elephant cross the". + answer: + The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need + to think about if you have a space before the answer here (as e.g. in this example the + answer may really be " road" if the prompt ends without a trailing space). + model: + The model. + prepend_space_to_answer: + Whether or not to prepend a space to the answer. Note this will only ever prepend a + space if the answer doesn't already start with one. + print_details: + Print the prompt (as a string but broken up by token), answer and top k tokens (all + with logit, rank and probability). + prepend_bos: + Overrides self.cfg.default_prepend_bos if set. Whether to prepend + the BOS token to the input (applicable when input is a string). Models generally learn + to use the BOS token as a resting place for attention heads (i.e. a way for them to be + "turned off"). This therefore often improves performance slightly. + top_k: + Top k tokens to print details of (when print_details is set to True). + + Returns: + None (just prints the results directly). """ if prepend_space_to_answer and not answer.startswith(" "): answer = " " + answer @@ -672,7 +710,6 @@ def test_prompt( rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}") -# %% def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]: """ Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions @@ -708,7 +745,6 @@ def composition_scores( return comp_norms / r_norms / l_norms -# %% def get_dataset(dataset_name: str, **kwargs) -> Dataset: """ Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). 
Note that it returns a dataset (i.e. a dictionary containing all the data), *not* a DataLoader (an iterator over the data plus some extra features). But you can easily convert it to a DataLoader.
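For example, a brief sketch of that conversion (the dataset alias and batch size are illustrative):

```python
from torch.utils.data import DataLoader

from transformer_lens.utils import get_dataset

dataset = get_dataset("pile")  # one of the small convenience datasets
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))  # a dict of columns, e.g. batch["text"] is a list of strings
```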