Fix excess attention_mask computation causing test failures
zazamrykh committed Jan 19, 2025
1 parent 792ae65 commit 33664d0
Showing 1 changed file with 5 additions and 11 deletions.
transformer_lens/HookedTransformer.py: 16 changes (5 additions & 11 deletions)
@@ -297,16 +297,6 @@ def get_residual(
         # Because tokens only need for defining batch size and sequence length, we can simply synthesize them
         tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device)
 
-        if attention_mask is None:
-            # If the padding side is left or we are using caching, we need to compute the attention
-            # mask for the adjustment of absolute positional embeddings and attention masking so
-            # that pad tokens are not attended.
-            if prepend_bos is USE_DEFAULT_VALUE:
-                prepend_bos = self.cfg.default_prepend_bos
-            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos).to(
-                device
-            )
-
         if self.cfg.positional_embedding_type == "standard":
             pos_embed = self.hook_pos_embed(
                 self.pos_embed(tokens, pos_offset, attention_mask)
@@ -2237,9 +2227,13 @@ def generate(
         sampled_tokens_list = []
         for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
             pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
+
+            tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int)
+            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos).to(device)
             residual, shortformer_pos_embed = self.get_residual(
-                embeds, pos_offset, return_shortformer_pos_embed=True, device=device
+                embeds, pos_offset, return_shortformer_pos_embed=True, device=device, attention_mask=attention_mask
             )
+
             # While generating, we keep generating logits, throw away all but the final logits,
             # and then use those logits to sample from the distribution We keep adding the
             # sampled tokens to the end of tokens.
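
Taken together, the two hunks move attention-mask construction out of get_residual and into the generate loop: get_residual no longer builds a mask from its synthesized placeholder tokens, and generate now computes the mask once per decoding step and passes it through explicitly. Below is a minimal sketch of the resulting calling pattern, written against this branch of HookedTransformer.py; the "gpt2" model, the CPU device, and the tensor shapes are illustrative assumptions, while get_residual, utils.get_attention_mask, and cfg.default_prepend_bos come from the diff above.

# Sketch of the post-commit calling pattern (assumed setup: small model on CPU;
# "gpt2" and the shapes below are illustrative, not part of the commit).
import torch
from transformer_lens import HookedTransformer, utils

model = HookedTransformer.from_pretrained("gpt2", device="cpu")
embeds = torch.randn(2, 5, model.cfg.d_model)  # stand-in input embeddings

# Synthesize integer tokens purely to carry the batch/sequence shape,
# mirroring what the updated generate() loop does.
tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int)

# The caller now builds the attention mask once ...
attention_mask = utils.get_attention_mask(
    model.tokenizer, tokens, model.cfg.default_prepend_bos
).to("cpu")

# ... and hands it to get_residual instead of letting get_residual recompute it.
residual, shortformer_pos_embed = model.get_residual(
    embeds,
    0,  # pos_offset at the first generation step
    return_shortformer_pos_embed=True,
    device="cpu",
    attention_mask=attention_mask,
)

Inside generate itself this happens once per step, with pos_offset coming from get_pos_offset and the mask rebuilt from the current embeds shape, as in the second hunk above.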