diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py
index ec2116375..545e39b3b 100644
--- a/tests/acceptance/test_hooked_transformer.py
+++ b/tests/acceptance/test_hooked_transformer.py
@@ -31,6 +31,7 @@
     "gelu-2l",
     "othello-gpt",
     "tiny-stories-33M",
+    "bloom-560m",
     "santacoder",
 ]
 text = "Hello world!"
@@ -57,6 +58,7 @@
     "redwood_attn_2l": 10.530948638916016,
     "solu-1l": 5.256411552429199,
     "tiny-stories-33M": 12.203617095947266,
+    "bloom-560m": 4.1953,
 }

 no_processing = [
diff --git a/tests/unit/test_compute_linear_bias.py b/tests/unit/test_compute_linear_bias.py
new file mode 100644
index 000000000..0a009889e
--- /dev/null
+++ b/tests/unit/test_compute_linear_bias.py
@@ -0,0 +1,40 @@
+import torch
+
+from transformer_lens.components import Attention
+
+
+def test_create_alibi_slope():
+    n_ctx = 100
+
+    # Expected result computed the non-vectorized way
+    expected = torch.zeros((n_ctx, n_ctx))
+    for row in range(n_ctx):
+        for col in range(n_ctx):
+            expected[row, col] = float(min(col - row, 0))
+
+    # Check against the method's vectorized version
+    result = Attention.create_alibi_slope(n_ctx)
+    assert torch.allclose(expected, result)
+
+
+def test_create_alibi_bias():
+    n_heads = 2
+    n_ctx = 4
+
+    result = Attention.create_alibi_bias(n_heads, n_ctx, torch.device("cpu"))
+
+    for matrix in result:
+        n_row, n_col = matrix.size()
+        slope = -matrix[1, 0]
+        # Check that the upper triangle is all zeros
+        assert torch.equal(torch.triu(matrix), torch.zeros_like(matrix))
+
+        ref_lower_triangle = torch.zeros_like(matrix)
+        for i in range(1, n_row):
+            for j in range(i):
+                ref_lower_triangle[i, j] = -slope * (i - j)
+
+        # Check that the lower triangle decreases by a constant slope (towards the bottom left corner).
+        assert torch.equal(
+            torch.tril(matrix, diagonal=-1), torch.tril(ref_lower_triangle, diagonal=-1)
+        )
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 357ca4c63..de8a6f27e 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -334,6 +334,10 @@ def input_to_embed(
             # keys and queries. See HookedTransformerConfig for details
             residual = embed
             shortformer_pos_embed = None
+        elif self.cfg.positional_embedding_type == "alibi":
+            # ALiBi does not add positional embeddings to word embeddings; instead, it biases the QK attention scores.
+            residual = embed
+            shortformer_pos_embed = None
         else:
             raise ValueError(
                 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
             )
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 347c8ae91..2730c3bd4 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -147,6 +147,8 @@ class HookedTransformerConfig:
         tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only
             when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True.
             We need this information to dynamically control bos prepending.
+        post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults
+            to False.
""" n_layers: int @@ -194,6 +196,7 @@ class HookedTransformerConfig: default_prepend_bos: bool = True dtype: torch.dtype = torch.float32 tokenizer_prepends_bos: Optional[bool] = None + post_embedding_ln: bool = False def __post_init__(self): if self.n_heads == -1: diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 480188713..9c8d663cd 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -32,12 +32,17 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter( torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype) ) + # Some models (e.g. Bloom) need post embedding layer norm + if cfg.post_embedding_ln: + self.ln = LayerNorm(cfg) def forward( self, tokens: Int[torch.Tensor, "batch pos"] ) -> Float[torch.Tensor, "batch pos d_model"]: # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d] # B acts as a tensor of indices into the second dimension (so >=0 and FactoredMatrix: @@ -533,7 +542,6 @@ def forward( qkv_einops_string = "batch pos head_index d_model" else: qkv_einops_string = "batch pos d_model" - q = self.hook_q( einsum( f"{qkv_einops_string}, head_index d_model d_head \ @@ -593,6 +601,22 @@ def forward( ) / self.attn_scale ) # [batch, head_index, query_pos, key_pos] + + if self.cfg.positional_embedding_type == "alibi": + query_ctx = attn_scores.size(-2) + # The key context length is the number of positions in the past - this includes all positions in the cache + key_ctx = attn_scores.size(-1) + + # only recompute when necessary to increase efficiency. + if self.alibi is None or key_ctx > self.alibi.size(-1): + self.alibi = Attention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device + ) + + attn_scores += self.alibi[ + :, :query_ctx, :key_ctx + ] # [batch, head_index, query_pos, key_pos] + if self.cfg.attention_dir == "causal": # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. attn_scores = self.apply_causal_mask( @@ -757,6 +781,136 @@ def apply_rotary( return torch.cat([x_rotated, x_pass], dim=-1) + @staticmethod + def create_alibi_slope( + n_ctx: int, device: torch.device = None + ) -> Float[torch.Tensor, "query key"]: + """Create an ALiBi Slope Matrix. + + Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar. + + See :meth:`create_alibi_bias` for the full ALiBi bias calculation. + + Examples: + + >>> Attention.create_alibi_slope(3) + tensor([[ 0., 0., 0.], + [-1., 0., 0.], + [-2., -1., 0.]]) + + >>> Attention.create_alibi_slope(4) + tensor([[ 0., 0., 0., 0.], + [-1., 0., 0., 0.], + [-2., -1., 0., 0.], + [-3., -2., -1., 0.]]) + + Args: + n_ctx: The maximum number of tokens in a prompt. + + Returns: + A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower + triangle is decreasing by a constant slope of 1 (towards the bottom left corner). + """ + # set rows as [[0,1,2...]] + rows = torch.arange(n_ctx, device=device).unsqueeze(0) + + # Set cols as [[0],[1],[2]...] 
+        cols = torch.arange(n_ctx, device=device).unsqueeze(1)
+
+        # Use broadcasting to create the desired lower triangular part of the matrix
+        slope_matrix = rows - cols
+
+        # Use the clamp method to set all positive values (upper right triangle) to 0
+        return slope_matrix.clamp(max=0).to(torch.float32)
+
+    @staticmethod
+    def create_alibi_multipliers(
+        n_heads: int, device: torch.device = None
+    ) -> Float[torch.Tensor, "head_idx"]:
+        """Create the ALiBi Scalar Multipliers for each Head.
+
+        For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
+        uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
+        1/(2^2), ..., 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ..., 1/(2^8)].
+
+        See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
+
+        Examples:
+
+        >>> Attention.create_alibi_multipliers(8)
+        tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
+
+        >>> Attention.create_alibi_multipliers(16)
+        tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
+                0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
+
+        Args:
+            n_heads: The number of heads in a layer.
+            device: The device to create the tensor on.
+
+        Returns:
+            A tensor of shape (n_heads,) containing the scalar multiplier for each head.
+        """
+        # Calculate the starting value
+        start = 2 ** (-8 / n_heads)
+
+        # Generate the indices [0, 1, ..., n_heads-1]
+        indices = torch.arange(n_heads, device=device)
+
+        # Compute the multipliers, with the starting value being the same as the ratio
+        multipliers = start * (start**indices)
+
+        return multipliers
+
+    @staticmethod
+    def create_alibi_bias(
+        n_heads: int, n_ctx: int, device: torch.device = None
+    ) -> Float[torch.Tensor, "head_idx query key"]:
+        """Create the ALiBi Bias for all Heads.
+
+        Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.
+
+        The broad idea behind ALiBi is to remove the positional encoding from the original transformer
+        model, and instead apply a bias to each attention score. This bias is proportional to the
+        distance between the query and key (i.e. it encourages paying less attention to more distant
+        tokens), and is added to the attention scores before the softmax. It is used in models such
+        as Bloom.
+
+        Examples:
+
+        >>> Attention.create_alibi_bias(2, 4, torch.device('cpu'))
+        tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
+                 [-0.0625,  0.0000,  0.0000,  0.0000],
+                 [-0.1250, -0.0625,  0.0000,  0.0000],
+                 [-0.1875, -0.1250, -0.0625,  0.0000]],
+                [[ 0.0000,  0.0000,  0.0000,  0.0000],
+                 [-0.0039,  0.0000,  0.0000,  0.0000],
+                 [-0.0078, -0.0039,  0.0000,  0.0000],
+                 [-0.0117, -0.0078, -0.0039,  0.0000]]])
+
+        Args:
+            n_heads: The number of heads in a layer.
+            n_ctx: The maximum number of tokens in a prompt.
+            device: The device to create the tensor on.
+
+        Returns:
+            The ALiBi bias that should be added to the attention scores before the softmax.
+        """
+        # Create the slope matrix
+        slope: Float[torch.Tensor, "query key"] = Attention.create_alibi_slope(
+            n_ctx, device
+        )
+
+        # Create the scalar multiplier for each head.
+        multipliers: Float[
+            torch.Tensor, "head_idx"
+        ] = Attention.create_alibi_multipliers(n_heads, device)
+
+        # The ALiBi bias is then m * slope_matrix
+        alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)
+
+        return alibi_bias
+

 # MLP Layers
 class MLP(nn.Module):
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 9a9107acd..7c7bf6a6c 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -136,6 +136,7 @@
     "stabilityai/stablelm-base-alpha-7b",
     "stabilityai/stablelm-tuned-alpha-3b",
     "stabilityai/stablelm-tuned-alpha-7b",
+    "bigscience/bloom-560m",
     "bigcode/santacoder",
 ]
 """Official model names for models on HuggingFace."""
@@ -495,6 +496,7 @@
         "stablelm-tuned-alpha-7b",
         "stablelm-tuned-7b",
     ],
+    "bigscience/bloom-560m": ["bloom-560m"],
     "bigcode/santacoder": ["santacoder"],
 }
 """Model aliases for models on HuggingFace."""
@@ -723,6 +725,22 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "act_fn": "gelu",
             "attention_dir": "bidirectional",
         }
+    elif architecture == "BloomForCausalLM":
+        cfg_dict = {
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.hidden_size // hf_config.n_head,
+            "n_heads": hf_config.n_head,
+            "d_mlp": hf_config.hidden_size * 4,
+            "n_layers": hf_config.n_layer,
+            "n_ctx": 2048,  # Capped due to HF Tokenizer Constraints
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": "gelu_fast",
+            "eps": hf_config.layer_norm_epsilon,
+            "normalization_type": "LN",
+            "post_embedding_ln": True,
+            "positional_embedding_type": "alibi",
+        }
+
     elif architecture == "GPT2LMHeadCustomModel":
         # santacoder
         cfg_dict = {
@@ -1069,6 +1087,8 @@ def get_pretrained_state_dict(
         state_dict = convert_llama_weights(hf_model, cfg)
     elif cfg.original_architecture == "BertForMaskedLM":
         state_dict = convert_bert_weights(hf_model, cfg)
+    elif cfg.original_architecture == "BloomForCausalLM":
+        state_dict = convert_bloom_weights(hf_model, cfg)
     elif cfg.original_architecture == "GPT2LMHeadCustomModel":
         state_dict = convert_coder_weights(hf_model, cfg)
     else:
@@ -1651,6 +1671,74 @@
     return state_dict


+def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    state_dict["embed.W_E"] = bloom.transformer.word_embeddings.weight
+
+    # Bloom uses post embedding layer norm
+    state_dict["embed.ln.w"] = bloom.transformer.word_embeddings_layernorm.weight
+    state_dict["embed.ln.b"] = bloom.transformer.word_embeddings_layernorm.bias
+
+    for l in range(cfg.n_layers):
+        state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight
+        state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias
+
+        # Bloom attn weight is stored as a fused matrix.
+        # BloomAttn: Linear(in=1024, out=3072)
+        # The .weight returned matrix will be in shape (3072, 1024)
+        W = bloom.transformer.h[l].self_attention.query_key_value.weight
+        # First transpose -> (1024, 3072), then split into (d_model, n_heads, 3, d_head)
+        W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)
+
+        W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]
+        W_Q = einops.rearrange(W_Q, "m n h -> n m h", n=cfg.n_heads)
+        W_K = einops.rearrange(W_K, "m n h -> n m h", n=cfg.n_heads)
+        W_V = einops.rearrange(W_V, "m n h -> n m h", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+        state_dict[f"blocks.{l}.attn.W_K"] = W_K
+        state_dict[f"blocks.{l}.attn.W_V"] = W_V
+
+        qkv_bias = bloom.transformer.h[l].self_attention.query_key_value.bias
+        qkv_bias = qkv_bias.reshape(cfg.n_heads, 3, cfg.d_head)
+
+        state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[:, 0, :]
+        state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[:, 1, :]
+        state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :]
+
+        W_O = bloom.transformer.h[l].self_attention.dense.weight.T  # [1024, 1024]
+        W_O = einops.rearrange(
+            W_O, "(n h) m -> n h m", n=cfg.n_heads
+        )  # [n_heads, d_head, d_model]
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O
+        state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[
+            l
+        ].self_attention.dense.bias
+
+        state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[
+            l
+        ].post_attention_layernorm.weight
+        state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[
+            l
+        ].post_attention_layernorm.bias
+
+        W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T
+        state_dict[f"blocks.{l}.mlp.W_in"] = W_in
+        state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[
+            l
+        ].mlp.dense_h_to_4h.bias
+
+        W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T
+        state_dict[f"blocks.{l}.mlp.W_out"] = W_out
+        state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[
+            l
+        ].mlp.dense_4h_to_h.bias
+    state_dict["unembed.W_U"] = bloom.lm_head.weight.T  # transpose to match shape
+
+    state_dict["ln_final.w"] = bloom.transformer.ln_f.weight
+    state_dict["ln_final.b"] = bloom.transformer.ln_f.bias
+    return state_dict
+
+
 def convert_coder_weights(model, cfg: HookedTransformerConfig):
     state_dict = {}
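
Usage sketch (not part of the patch): a minimal example of how the Bloom support and ALiBi helpers added above could be exercised. It assumes transformer_lens is installed with this change applied and that the bigscience/bloom-560m weights can be downloaded from HuggingFace; the reference loss of 4.1953 comes from the acceptance test above.

import torch

from transformer_lens import HookedTransformer
from transformer_lens.components import Attention

# Build the ALiBi bias directly: shape [n_heads, query_pos, key_pos], with a zero
# upper triangle and a per-head constant slope in the lower triangle.
alibi = Attention.create_alibi_bias(n_heads=2, n_ctx=4, device=torch.device("cpu"))
print(alibi.shape)  # torch.Size([2, 4, 4])

# Load the newly supported model and compare the loss to the acceptance-test value.
model = HookedTransformer.from_pretrained("bloom-560m")
loss = model("Hello world!", return_type="loss")
print(loss.item())  # expected to be close to 4.1953

Because the Attention module caches the bias and only rebuilds it when the key context grows beyond self.alibi.size(-1), repeated forward passes with a key-value cache reuse the same tensor and simply slice it to the current query and key lengths.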