diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index a7eb31e70..3347cf0fd 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -297,16 +297,6 @@ def get_residual(
         # Because tokens only need for defining batch size and sequence length, we can simply synthesize them
         tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device)
 
-        if attention_mask is None:
-            # If the padding side is left or we are using caching, we need to compute the attention
-            # mask for the adjustment of absolute positional embeddings and attention masking so
-            # that pad tokens are not attended.
-            if prepend_bos is USE_DEFAULT_VALUE:
-                prepend_bos = self.cfg.default_prepend_bos
-            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos).to(
-                device
-            )
-
         if self.cfg.positional_embedding_type == "standard":
             pos_embed = self.hook_pos_embed(
                 self.pos_embed(tokens, pos_offset, attention_mask)
@@ -2237,9 +2227,19 @@ def generate(
             sampled_tokens_list = []
             for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
                 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
+
+                tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int)
+                attention_mask = utils.get_attention_mask(
+                    self.tokenizer, tokens, False if prepend_bos is None else prepend_bos
+                ).to(device)
                 residual, shortformer_pos_embed = self.get_residual(
-                    embeds, pos_offset, return_shortformer_pos_embed=True, device=device
+                    embeds,
+                    pos_offset,
+                    return_shortformer_pos_embed=True,
+                    device=device,
+                    attention_mask=attention_mask,
                 )
+
                 # While generating, we keep generating logits, throw away all but the final logits,
                 # and then use those logits to sample from the distribution We keep adding the
                 # sampled tokens to the end of tokens.
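
The net effect of the two hunks is that `get_residual` no longer builds an attention mask as a fallback; the `generate()` loop now synthesizes placeholder tokens, computes the mask via `utils.get_attention_mask`, and passes it in explicitly. A minimal sketch of what an external caller feeding embeddings directly into `get_residual` might look like after this change is below. This is an illustration only: the model name `"gpt2"`, the zero `pos_offset`, and the exact keyword arguments of `get_residual` are assumptions taken from the hunks above, not verified against the full file.

```python
import torch

from transformer_lens import HookedTransformer, utils

# Assumed model for illustration purposes only.
model = HookedTransformer.from_pretrained("gpt2")

tokens = model.to_tokens("An example prompt")  # [batch, pos]
embeds = model.embed(tokens)                   # [batch, pos, d_model]

# Placeholder token ids, mirroring the generate() hunk above:
# only the batch size and sequence length matter for building the mask.
dummy_tokens = torch.zeros(embeds.size(0), embeds.size(1), dtype=torch.int)
attention_mask = utils.get_attention_mask(
    model.tokenizer, dummy_tokens, model.cfg.default_prepend_bos
).to(embeds.device)

# Since get_residual no longer computes a mask itself, the caller supplies it.
residual, shortformer_pos_embed = model.get_residual(
    embeds,
    0,  # pos_offset: assumed 0 here, i.e. no KV cache offset
    return_shortformer_pos_embed=True,
    device=embeds.device,
    attention_mask=attention_mask,
)
print(residual.shape)  # [batch, pos, d_model]
```

One reading of the design choice: computing the mask at the call site keeps `get_residual` a function of its explicit inputs, rather than having it re-derive `prepend_bos` defaults and tokenizer state internally on every call.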