From 9a89bc2e574f34048e4ed85fef814159ed35492e Mon Sep 17 00:00:00 2001
From: Jan Bauer
Date: Sun, 19 Jan 2025 15:13:39 +0000
Subject: [PATCH] More cleanup

---
 transformer_lens/HookedTransformer.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 307f2bfe5..4244e2ee1 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -114,7 +114,6 @@ def __init__(
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         move_to_device: bool = True,
         default_padding_side: Literal["left", "right"] = "right",
-        zero_pos_embed: bool = False,
     ):
         """Model initialization.
 
@@ -139,7 +138,6 @@ def __init__(
             )
 
         self.cfg = HookedTransformerConfig.unwrap(cfg)
-        self.zero_pos_embed = zero_pos_embed
 
         if tokenizer is not None:
             self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
@@ -189,7 +187,7 @@ def __init__(
         self.hook_tokens = HookPoint()  # [batch, pos]
 
         self.blocks = nn.ModuleList(
-            [TransformerBlock(self.cfg, block_index, self.zero_pos_embed) for block_index in range(self.cfg.n_layers)]
+            [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
         )
 
         if self.cfg.normalization_type == "RMS":
@@ -358,7 +356,7 @@ def input_to_embed(
             pos_embed = self.hook_pos_embed(
                 self.pos_embed(tokens, pos_offset, attention_mask)
             )  # [batch, pos, d_model]
-            residual = embed + pos_embed * (1. if not self.zero_pos_embed else 0.)  # [batch, pos, d_model]
+            residual = embed + pos_embed  # [batch, pos, d_model]
             shortformer_pos_embed = None
         elif self.cfg.positional_embedding_type == "shortformer":
             # If we're using shortformer style attention, we don't add the positional embedding to
@@ -367,7 +365,7 @@ def input_to_embed(
                 self.pos_embed(tokens, pos_offset, attention_mask)
             )  # [batch, pos, d_model]
             residual = embed
-            shortformer_pos_embed = pos_embed * (1. if not self.zero_pos_embed else 0.)
+            shortformer_pos_embed = pos_embed
         elif self.cfg.positional_embedding_type == "rotary":
             # Rotary doesn't use positional embeddings, instead they're applied when dot producting
             # keys and queries. See HookedTransformerConfig for details
@@ -573,7 +571,7 @@ def forward(
             if shortformer_pos_embed is not None:
                 shortformer_pos_embed = shortformer_pos_embed.to(
                     devices.get_device_for_block_index(i, self.cfg)
-                ) * (1. if not self.zero_pos_embed else 0.)
+                )
 
             residual = block(
                 residual,
@@ -1088,7 +1086,6 @@ def from_pretrained(
         default_prepend_bos: Optional[bool] = None,
         default_padding_side: Literal["left", "right"] = "right",
         dtype="float32",
-        zero_pos_embed: bool = False,
         first_n_layers: Optional[int] = None,
         **from_pretrained_kwargs,
     ) -> T:
@@ -1332,7 +1329,6 @@ def from_pretrained(
             tokenizer,
             move_to_device=False,
             default_padding_side=default_padding_side,
-            zero_pos_embed=zero_pos_embed,
         )
 
         model.load_and_process_state_dict(
@@ -2238,7 +2234,7 @@ def forward_(*model_args, **model_kwargs):
                 )
             final_logits = logits[:, -1, :]
 
-            # SAMPLING
+            # sampling
            if do_sample:
                 sampled_tokens = utils.sample_logits(
                     final_logits,
@@ -2265,10 +2261,11 @@ def forward_(*model_args, **model_kwargs):
                     )
                 )
 
-            # update the tokens!
+            # concatenate the new tokens
             tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)
 
-            # APPEND – we need to clone on the first pass to prevent overwrite
+            # concatenate the cache
+            # we need to clone on the first pass to prevent overwrite
             token_tape = torch.cat([token_tape, sampled_tokens.unsqueeze(-1)], dim=-1) if token_tape is not None else torch.clone(tokens[:, -ctx_length:])  # appends to the prompt tokens
             if return_cache:
                 def cat_cache_var(key, var_tape, var):
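With the zero_pos_embed flag removed, the same ablation is still available from outside the model via a forward hook on hook_pos_embed, which the hunks above leave in place. A minimal sketch, assuming a model with learned positional embeddings (the gpt2 checkpoint and prompt below are illustrative only, not part of this patch):

    import torch
    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")
    tokens = model.to_tokens("Positional information matters here.")

    def zero_pos_embed_hook(pos_embed, hook):
        # Replace the learned positional embedding with zeros, reproducing the
        # effect of the removed zero_pos_embed=True path.
        return torch.zeros_like(pos_embed)

    logits = model.run_with_hooks(
        tokens,
        fwd_hooks=[("hook_pos_embed", zero_pos_embed_hook)],
    )

For rotary models there is no additive positional embedding to zero out; as the comment in the diff notes, position enters through the rotated keys and queries instead.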