HookedSAETransformer (#536)
* implement HookedSAETransformer

* clean up imports

* apply format

* only recompute error if use_error_term

* add tests

* run format

* fix import

* match to hooks API

* improve doc strings

* improve demo

* address Arthur feedback

* try to fix indent:

* try to fix indent again

* change doc code block
ckkissane authored Apr 30, 2024
1 parent 1139caf commit ca6b8db
Showing 8 changed files with 19,799 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -24,7 +24,7 @@ TransformerLens lets you load in 50+ different open source language models, and
activations of the model to you. You can cache any internal activation in the model, and add in
functions to edit, remove or replace these activations as the model runs.

-~~ [OCTOBER SURVEY HERE](https://forms.gle/bw7U3PfioacDtFmT8) ~~
+The library also now supports mechanistic interpretability with SAEs (sparse autoencoders)! With [HookedSAETransformer](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/hooked-sae-transformer/demos/HookedSAETransformerDemo.ipynb), you can splice in SAEs during inference and cache + intervene on SAE activations. We recommend [SAELens](https://github.com/jbloomAus/SAELens) (built on top of TransformerLens) for training SAEs.

## Quick Start

@@ -51,6 +51,7 @@ logits, activations = model.run_with_cache("Hello World")
* [Introduction to the Library and Mech
Interp](https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp)
* [Demo of Main TransformerLens Features](https://neelnanda.io/transformer-lens-demo)
+* [Demo of HookedSAETransformer Features](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/hooked-sae-transformer/demos/HookedSAETransformerDemo.ipynb)

## Gallery

18,616 changes: 18,616 additions & 0 deletions demos/HookedSAETransformerDemo.ipynb

Large diffs are not rendered by default.
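
For orientation, since the demo notebook is not rendered here, the following is a minimal sketch of the kind of usage it covers: running a HookedSAE standalone on cached model activations, reading its internal hooks, and intervening on its feature activations. (Splicing SAEs into the model's own forward pass is shown in the demo itself.) The sketch uses only the API exercised by the unit tests below; the model name, hook point, and SAE width are illustrative, not part of this commit.

```python
from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("solu-1l")

# Configure an SAE for one hook point; d_in must match that activation's width.
act_name = "blocks.0.hook_resid_pre"
cfg = HookedSAEConfig(d_in=model.cfg.d_model, d_sae=model.cfg.d_model * 2, hook_name=act_name)
sae = HookedSAE(cfg)

# Cache the model activation at that hook point, then run the SAE on it.
_, cache = model.run_with_cache("Hello World!", names_filter=act_name)
x = cache[act_name]
sae_out, sae_cache = sae.run_with_cache(x)
print(sae_cache["hook_sae_acts_post"].shape)  # feature activations; last dim is d_sae

# Intervene on the SAE's feature activations with an ordinary forward hook.
def zero_feature_0(act, hook):
    act[..., 0] = 0.0  # ablate feature 0 at every position
    return act

ablated_out = sae.run_with_hooks(x, fwd_hooks=[("hook_sae_acts_post", zero_feature_0)])
```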

191 changes: 191 additions & 0 deletions tests/unit/test_hooked_sae.py
@@ -0,0 +1,191 @@
import einops
import pytest
import torch

from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer

MODEL = "solu-1l"
prompt = "Hello World!"


class Counter:
    def __init__(self):
        self.count = 0

    def inc(self, *args, **kwargs):
        self.count += 1


@pytest.fixture(scope="module")
def model():
    model = HookedSAETransformer.from_pretrained(MODEL)
    yield model
    model.reset_saes()


def get_sae_config(model, act_name):
    site_to_size = {
        "hook_z": model.cfg.d_head * model.cfg.n_heads,
        "hook_mlp_out": model.cfg.d_model,
        "hook_resid_pre": model.cfg.d_model,
        "hook_post": model.cfg.d_mlp,
    }
    site = act_name.split(".")[-1]
    d_in = site_to_size[site]
    return HookedSAEConfig(d_in=d_in, d_sae=d_in * 2, hook_name=act_name)


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_forward_reconstructs_input(model, act_name):
    """Verify that the HookedSAE returns an output with the same shape as the input activations."""
    sae_cfg = get_sae_config(model, act_name)
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_output = hooked_sae(x)
    assert sae_output.shape == x.shape


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_run_with_cache(model, act_name):
    """Verifies that run_with_cache caches SAE activations"""
    sae_cfg = get_sae_config(model, act_name)
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_output, cache = hooked_sae.run_with_cache(x)
    assert sae_output.shape == x.shape

    assert "hook_sae_input" in cache
    assert "hook_sae_acts_pre" in cache
    assert "hook_sae_acts_post" in cache
    assert "hook_sae_recons" in cache
    assert "hook_sae_output" in cache


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_run_with_hooks(model, act_name):
    """Verifies that run_with_hooks works with SAE activations"""
    c = Counter()
    sae_cfg = get_sae_config(model, act_name)
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_hooks = [
        "hook_sae_input",
        "hook_sae_acts_pre",
        "hook_sae_acts_post",
        "hook_sae_recons",
        "hook_sae_output",
    ]

    sae_output = hooked_sae.run_with_hooks(
        x, fwd_hooks=[(sae_hook_name, c.inc) for sae_hook_name in sae_hooks]
    )
    assert sae_output.shape == x.shape

    assert c.count == len(sae_hooks)


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_error_term(model, act_name):
    """Verifies that if we use error_terms, HookedSAE returns an output that is equal to the input activations."""
    sae_cfg = get_sae_config(model, act_name)
    sae_cfg.use_error_term = True
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_output = hooked_sae(x)
    assert sae_output.shape == x.shape
    assert torch.allclose(sae_output, x, atol=1e-6)


# %%
@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_feature_grads_with_error_term(model, act_name):
    """Verifies that pytorch backward computes the correct feature gradients when using error_terms. Motivated by the need to compute feature gradients for attribution patching."""

    # Load SAE
    sae_cfg = get_sae_config(model, act_name)
    sae_cfg.use_error_term = True
    hooked_sae = HookedSAE(sae_cfg)

    # Get input activations
    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    # Cache gradients with respect to feature acts
    hooked_sae.reset_hooks()
    grad_cache = {}

    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()

    hooked_sae.add_hook("hook_sae_acts_post", backward_cache_hook, "bwd")
    hooked_sae.add_hook("hook_sae_output", backward_cache_hook, "bwd")

    sae_output = hooked_sae(x)
    assert torch.allclose(sae_output, x, atol=1e-6)
    value = sae_output.sum()
    value.backward()
    hooked_sae.reset_hooks()

    # Compute gradient analytically: the output depends on the feature activations
    # only through the linear decoder, and the error term contributes no gradient,
    # so grad(acts_post) = grad(output) @ W_dec.T (reshaping hook_z outputs first).
    if act_name.endswith("hook_z"):
        reshaped_output_grad = einops.rearrange(
            grad_cache["hook_sae_output"], "... n_heads d_head -> ... (n_heads d_head)"
        )
        analytic_grad = reshaped_output_grad @ hooked_sae.W_dec.T
    else:
        analytic_grad = grad_cache["hook_sae_output"] @ hooked_sae.W_dec.T

    # Compare analytic gradient with pytorch computed gradient
    assert torch.allclose(grad_cache["hook_sae_acts_post"], analytic_grad, atol=1e-6)
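
The docstring of the last test cites attribution patching as the motivation for correct feature gradients. As a follow-on (not part of this commit), here is a minimal sketch of turning those gradients into per-feature attribution scores with a standard first-order grad-times-activation estimate, using only the API exercised above; the metric is a stand-in for a real task metric, and the model, hook point, and SAE width are illustrative.

```python
from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("solu-1l")
act_name = "blocks.0.hook_resid_pre"
cfg = HookedSAEConfig(d_in=model.cfg.d_model, d_sae=model.cfg.d_model * 2, hook_name=act_name)
cfg.use_error_term = True  # SAE output then matches the input activations exactly (see test_error_term)
sae = HookedSAE(cfg)

_, cache = model.run_with_cache("Hello World!", names_filter=act_name)
x = cache[act_name]

acts, grads = {}, {}

def save_acts(act, hook):
    acts[hook.name] = act.detach()

def save_grads(grad, hook):
    grads[hook.name] = grad.detach()

sae.add_hook("hook_sae_acts_post", save_acts)  # forward pass: feature activations
sae.add_hook("hook_sae_acts_post", save_grads, "bwd")  # backward pass: feature gradients

metric = sae(x).sum()  # stand-in metric; attribution patching would use a task-specific one
metric.backward()
sae.reset_hooks()

# First-order estimate of each feature's contribution to the metric.
attribution = acts["hook_sae_acts_post"] * grads["hook_sae_acts_post"]
print(attribution.shape)  # [..., d_sae]
```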