Enables return of activation cache variables during generation #838

Draft · wants to merge 16 commits into base: main
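For context, a rough usage sketch of what this would enable, inferred from the diff below; the model name and prompt are illustrative, and the exact return format may change while the PR is a draft:

# Sketch only: assumes this PR's branch of TransformerLens.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # "gpt2" is just an example model
prompt_tokens = model.to_tokens("The quick brown fox")  # tensor input, so generate returns tensors

tokens, logits, cache = model.generate(
    prompt_tokens,
    max_new_tokens=8,
    return_cache=True,  # new flag added in this PR
    verbose=False,
)
# tokens: the accumulated token tape (context window plus sampled tokens)
# logits: per-position logits accumulated across the generation loop
# cache:  dict of activation tensors, grown along their position dims each step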
91 changes: 87 additions & 4 deletions transformer_lens/HookedTransformer.py
@@ -59,6 +59,7 @@
from transformer_lens.utilities import devices
from transformer_lens.utils import (
USE_DEFAULT_VALUE,
Slice,
init_kaiming_normal_,
init_kaiming_uniform_,
init_xavier_normal_,
@@ -2043,6 +2044,7 @@ def generate(
prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
return_type: Optional[str] = "input",
return_cache: bool = False,
verbose: bool = True,
) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]:
"""Sample Tokens from the Model.
@@ -2154,22 +2156,67 @@ def generate(
# Currently nothing in HookedTransformer changes with eval, but this is here in case
# that changes in the future.
self.eval()
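# Running "tapes" that accumulate the sampled tokens, per-step logits, and
# cached activations across the generation loop.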
logits_tape = None
cache_dict_tape = None
token_tape = None

if return_cache:
# defaults from hook_points.py#L510
names_filter = None
device = None
remove_batch_dim: bool = False
incl_bwd: bool = False
reset_hooks_end: bool = True
clear_contexts: bool = True  # changed from the default of False
pos_slice = None

pos_slice = Slice.unwrap(pos_slice)

cache_dict, fwd, bwd = self.get_caching_hooks(
names_filter,
incl_bwd,
device,
remove_batch_dim=remove_batch_dim,
pos_slice=pos_slice,
)
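# fwd / bwd are lists of (hook_name, hook_fn) pairs; the hooks write each
# activation into cache_dict as the forward pass runs.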

def forward_(*model_args, **model_kwargs):
if return_cache:
# cache_dict is populated in place by the caching hooks during the forward pass
with self.hooks(
fwd_hooks=fwd,
bwd_hooks=bwd,
reset_hooks_end=reset_hooks_end,
clear_contexts=clear_contexts,
):
model_out = self(*model_args, **model_kwargs)
if incl_bwd:
model_out.backward()
return model_out
else:
model_out = self.forward(*model_args, **model_kwargs)
return model_out

for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
# While generating, we keep computing logits, throw away all but the final logits,
# and then use those logits to sample from the distribution. We keep appending the
# sampled tokens to the end of tokens.

# forwarding
if use_past_kv_cache:
# We just take the final tokens, as a [batch, 1] tensor
if index > 0:
logits = self.forward(
logits = forward_(
tokens[:, -1:],
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
past_kv_cache=past_kv_cache,
)
else:
logits = self.forward(
logits = forward_(
tokens,
return_type="logits",
prepend_bos=prepend_bos,
@@ -2179,14 +2226,15 @@ def generate(
else:
# We input the entire sequence, as a [batch, pos] tensor, since we aren't using
# the cache.
logits = self.forward(
logits = forward_(
tokens,
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
)
final_logits = logits[:, -1, :]

# sampling
if do_sample:
sampled_tokens = utils.sample_logits(
final_logits,
@@ -2213,8 +2261,40 @@
)
)

# concatenate the new tokens
tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)

# accumulate the running tapes; we need to clone on the first pass to prevent
# overwriting the source tensors
token_tape = (
    torch.cat([token_tape, sampled_tokens.unsqueeze(-1)], dim=-1)
    if token_tape is not None
    else torch.clone(tokens[:, -ctx_length:])  # starts from the last ctx_length tokens
)
if return_cache:
def cat_cache_var(key, var_tape, var):
if not any(key_ in key for key_ in ['attn_scores', 'hook_pattern']): # vector-valued activations of shape [batch, pos, ...]
cat_var = torch.cat([var_tape, var[:, -1:]], dim=1)
return cat_var
else:
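# attention scores / patterns have shape [batch, head_index, query_pos, key_pos];
# each generated token adds one query row and one key column, so both trailing
# dims have to grow by one each step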
var_tape = torch.nn.functional.pad(var_tape, (0, 1, 0, 1), value=0) # zero-pad one extra key column and one extra query row
slice1 = var[:, :, -1:, :] # the new query row across all key positions
T = slice1.shape[-1] # current key (context) length
var_tape[..., -1:, -T:] = slice1

# fill in the new (last) key column
slice2 = var[:, :, :, -1:]
var_tape[..., -T:, -1:] = slice2
return var_tape

cache_dict_tape = (
{k: cat_cache_var(k, cache_dict_tape[k], cache_dict[k]) for k in cache_dict}
if cache_dict_tape is not None
else {k: torch.clone(v) for k, v in cache_dict.items()} # initializes the dict with the initial cache
)

logits_tape = (
torch.cat([logits_tape, logits[:, -1:]], dim=1)
if logits_tape is not None
else torch.clone(logits[:, -ctx_length:])
)

if stop_at_eos and finished_sequences.all():
break

@@ -2226,7 +2306,10 @@ def generate(
return self.tokenizer.decode(tokens[0])

else:
return tokens
if return_cache:
return token_tape, logits_tape, cache_dict_tape # consider wrapping in ActivationCache()
else:
return tokens

# Give access to all weights as properties.
@property