More cleanup
Jan Bauer committed Jan 19, 2025
1 parent cb056c6 commit 9a89bc2
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions transformer_lens/HookedTransformer.py
@@ -114,7 +114,6 @@ def __init__(
tokenizer: Optional[PreTrainedTokenizerBase] = None,
move_to_device: bool = True,
default_padding_side: Literal["left", "right"] = "right",
- zero_pos_embed: bool = False,
):
"""Model initialization.
@@ -139,7 +138,6 @@ def __init__(
)

self.cfg = HookedTransformerConfig.unwrap(cfg)
- self.zero_pos_embed = zero_pos_embed

if tokenizer is not None:
self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
@@ -189,7 +187,7 @@ def __init__(
self.hook_tokens = HookPoint() # [batch, pos]

self.blocks = nn.ModuleList(
- [TransformerBlock(self.cfg, block_index, self.zero_pos_embed) for block_index in range(self.cfg.n_layers)]
+ [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
)

if self.cfg.normalization_type == "RMS":
@@ -358,7 +356,7 @@ def input_to_embed(
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
- residual = embed + pos_embed * (1. if not self.zero_pos_embed else 0.) # [batch, pos, d_model]
+ residual = embed + pos_embed # [batch, pos, d_model]
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "shortformer":
# If we're using shortformer style attention, we don't add the positional embedding to
@@ -367,7 +365,7 @@
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed
- shortformer_pos_embed = pos_embed * (1. if not self.zero_pos_embed else 0.)
+ shortformer_pos_embed = pos_embed
elif self.cfg.positional_embedding_type == "rotary":
# Rotary doesn't use positional embeddings, instead they're applied when dot producting
# keys and queries. See HookedTransformerConfig for details
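The two hunks above drop the zero_pos_embed scaling from the standard and shortformer positional-embedding paths. For anyone who relied on that flag, the same ablation can still be expressed through the existing hook_pos_embed hook point; the snippet below is a minimal sketch assuming the standard TransformerLens hook API, and is not part of this commit (the model name is an arbitrary example).

import torch
from transformer_lens import HookedTransformer

# Sketch only: zero out the positional embedding via a hook instead of the removed flag.
model = HookedTransformer.from_pretrained("gpt2")

def zero_pos_embed_hook(pos_embed, hook):
    # Replace the positional embedding with zeros before it enters the residual stream.
    return torch.zeros_like(pos_embed)

tokens = model.to_tokens("Hello, world")
logits = model.run_with_hooks(tokens, fwd_hooks=[("hook_pos_embed", zero_pos_embed_hook)])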
@@ -573,7 +571,7 @@ def forward(
if shortformer_pos_embed is not None:
shortformer_pos_embed = shortformer_pos_embed.to(
devices.get_device_for_block_index(i, self.cfg)
- ) * (1. if not self.zero_pos_embed else 0.)
+ )

residual = block(
residual,
Expand Down Expand Up @@ -1088,7 +1086,6 @@ def from_pretrained(
default_prepend_bos: Optional[bool] = None,
default_padding_side: Literal["left", "right"] = "right",
dtype="float32",
- zero_pos_embed: bool = False,
first_n_layers: Optional[int] = None,
**from_pretrained_kwargs,
) -> T:
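With zero_pos_embed removed from the signature above, a typical from_pretrained call simply drops that keyword. The following is an illustrative sketch; the model name and keyword values are arbitrary examples, not taken from the diff.

from transformer_lens import HookedTransformer

# Illustrative only: a typical call after this change, with no zero_pos_embed keyword.
model = HookedTransformer.from_pretrained(
    "gpt2",
    dtype="float32",
    default_padding_side="right",
)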
@@ -1332,7 +1329,6 @@ def from_pretrained(
tokenizer,
move_to_device=False,
default_padding_side=default_padding_side,
- zero_pos_embed=zero_pos_embed,
)

model.load_and_process_state_dict(
@@ -2238,7 +2234,7 @@ def forward_(*model_args, **model_kwargs):
)
final_logits = logits[:, -1, :]

- # SAMPLING
+ # sampling
if do_sample:
sampled_tokens = utils.sample_logits(
final_logits,
@@ -2265,10 +2261,11 @@
)
)

- # update the tokens!
+ # concatenate the new tokens
tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)

- # APPEND – we need to clone on the first pass to prevent overwrite
+ # concatenate the cache
+ # we need to clone on the first pass to prevent overwrite
token_tape = torch.cat([token_tape, sampled_tokens.unsqueeze(-1)], dim=-1) if token_tape is not None else torch.clone(tokens[:, -ctx_length:]) # appends to the prompt tokens
if return_cache:
def cat_cache_var(key, var_tape, var):
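The final hunk only rewrites the comments around the token-append logic. As a rough, standalone illustration of that pattern (tensor names and shapes are made up, and utils.sample_logits is replaced by a greedy argmax), the append step works roughly as follows; this is a sketch, not the actual generate loop.

import torch

# Standalone sketch of the append pattern above.
batch, vocab = 2, 50257
tokens = torch.randint(0, vocab, (batch, 5))   # prompt tokens so far
final_logits = torch.randn(batch, vocab)       # logits at the last position

# Greedy stand-in for utils.sample_logits.
sampled_tokens = final_logits.argmax(dim=-1)   # [batch]

# Append the newly sampled token to the running sequence.
tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)  # [batch, pos + 1]

# Maintain a separate tape of tokens; clone on the first pass so the tape holds a
# copy rather than a view of `tokens` (the "prevent overwrite" comment above).
token_tape = None
token_tape = (
    torch.cat([token_tape, sampled_tokens.unsqueeze(-1)], dim=-1)
    if token_tape is not None
    else torch.clone(tokens)
)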
