Commit 143c686
Fix excess attention_mask computation causing test failures
zazamrykh committed Jan 19, 2025
1 parent 792ae65 commit 143c686
Showing 1 changed file with 11 additions and 11 deletions.
transformer_lens/HookedTransformer.py: 22 changes (11 additions & 11 deletions)
@@ -297,16 +297,6 @@ def get_residual(
         # Because tokens only need for defining batch size and sequence length, we can simply synthesize them
         tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device)
 
-        if attention_mask is None:
-            # If the padding side is left or we are using caching, we need to compute the attention
-            # mask for the adjustment of absolute positional embeddings and attention masking so
-            # that pad tokens are not attended.
-            if prepend_bos is USE_DEFAULT_VALUE:
-                prepend_bos = self.cfg.default_prepend_bos
-            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos).to(
-                device
-            )
-
         if self.cfg.positional_embedding_type == "standard":
             pos_embed = self.hook_pos_embed(
                 self.pos_embed(tokens, pos_offset, attention_mask)
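
With this deletion, get_residual no longer infers an attention mask from its synthesized placeholder tokens; callers that run the model on embeddings are expected to build the mask themselves and pass it in, as the generate() hunk below now does. A minimal caller-side sketch of that pattern, assuming a HookedTransformer instance `model` and pre-existing `tokens`, `embeds`, `pos_offset`, and `device` (the keyword arguments simply mirror the call shown in the second hunk; this is an illustration, not the library's documented usage):

    # Hedged sketch: `model`, `tokens`, `embeds`, `pos_offset`, and `device`
    # are assumed to already exist in the caller's scope.
    import transformer_lens.utils as utils

    # Build the mask from the real token ids instead of relying on
    # get_residual to derive one from placeholder tokens.
    attention_mask = utils.get_attention_mask(
        model.tokenizer, tokens, model.cfg.default_prepend_bos
    ).to(device)

    residual, shortformer_pos_embed = model.get_residual(
        embeds,
        pos_offset,
        return_shortformer_pos_embed=True,
        device=device,
        attention_mask=attention_mask,
    )
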
@@ -2237,9 +2227,19 @@ def generate(
             sampled_tokens_list = []
             for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
                 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
+
+                tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int)
+                attention_mask = utils.get_attention_mask(
+                    self.tokenizer, tokens, False if prepend_bos is None else prepend_bos
+                ).to(device)
                 residual, shortformer_pos_embed = self.get_residual(
-                    embeds, pos_offset, return_shortformer_pos_embed=True, device=device
+                    embeds,
+                    pos_offset,
+                    return_shortformer_pos_embed=True,
+                    device=device,
+                    attention_mask=attention_mask,
                 )
+
                 # While generating, we keep generating logits, throw away all but the final logits,
                 # and then use those logits to sample from the distribution We keep adding the
                 # sampled tokens to the end of tokens.
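
The mask built here is the usual [batch, pos] tensor of ones and zeros that keeps padded positions from being attended to, which is what the comment removed from get_residual referred to. A purely illustrative, self-contained PyTorch sketch of that masking step (plain torch, not TransformerLens internals):

    import torch

    # Toy batch: the second sequence has two padded positions (marked 0).
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [1, 1, 0, 0]])
    scores = torch.randn(2, 1, 4, 4)  # [batch, head, query_pos, key_pos]

    # Give padded key positions -inf so softmax assigns them zero weight.
    masked = scores.masked_fill(attention_mask[:, None, None, :] == 0, float("-inf"))
    probs = masked.softmax(dim=-1)  # each query row sums to 1 over real keys only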