Added support for bloom-560m model #434

Merged: 43 commits, Nov 10, 2023

Changes from 2 commits

Commits (43)
38f2ff4
bloom support
Oct 21, 2023
6f1dd96
Merge branch 'main' of github.com:SeuperHakkerJa/TransformerLens into…
Oct 21, 2023
d698085
refine comments, fix bug, add test and notebooks
Oct 23, 2023
8194f06
fixing build error
Oct 23, 2023
70d3dd2
fix import order
Oct 23, 2023
e9623af
fix run black
Oct 23, 2023
5edc1ea
fix more formatting problems
Oct 23, 2023
03a707c
fix compute_alibi_tensor
Oct 27, 2023
bedb53f
fix docstring
Oct 27, 2023
e7c20dd
fix expand alibi
Oct 28, 2023
f1588cf
fix util docs
Oct 28, 2023
8fd8940
fix run blck error
Oct 28, 2023
2720cc5
delete unwanted comments
Oct 28, 2023
0ce7b22
add unit test for expand
Oct 28, 2023
8cf50f2
fix calculating alibi
Oct 28, 2023
b57cda9
change type hint test.
Oct 28, 2023
dbb8435
add test
Oct 28, 2023
aa4ab7f
simplify test
Oct 28, 2023
d09d246
delete test
Oct 28, 2023
58ba044
delete compute linear bias unit test
Oct 28, 2023
96bf38b
add back unit test, alibi functions to static
Nov 3, 2023
6359fce
add tensor shape to type hint
Nov 3, 2023
72af85b
remove init huge alibi
Nov 3, 2023
3987795
fix style
Nov 3, 2023
5006d85
fix style
Nov 3, 2023
7ae0955
add back comment
Nov 3, 2023
0801ad6
Remove notebook demo
alan-cooney Nov 9, 2023
df1e3a9
Add comment on n_ctx
alan-cooney Nov 10, 2023
ba60f93
Add docs hot reloading instructions for contributors (#436)
alan-cooney Oct 22, 2023
35f145e
Make unit & acceptance tests run in parallel (#435)
alan-cooney Oct 22, 2023
3b606ed
Update GitHub CD Actions (#437)
alan-cooney Oct 22, 2023
0b2c71c
Organise & fix README (#430)
alan-cooney Oct 22, 2023
960fd06
Update README.md (#440)
jbloomAus Oct 23, 2023
28d4542
Relax cuda requirements (#442)
alan-cooney Oct 26, 2023
1673e53
Add tests to the main demo and push to the website (#441)
bryce13950 Oct 28, 2023
7f86848
Fix docs command typo (#444)
alan-cooney Oct 28, 2023
5410d12
New model: bigcode/santacoder (#445)
ojh31 Nov 8, 2023
751c91a
Fix merge part 3
alan-cooney Nov 10, 2023
20a02e8
Organise & fix README (#430)
alan-cooney Oct 22, 2023
f178e66
Update README.md (#440)
jbloomAus Oct 23, 2023
820a206
Relax cuda requirements (#442)
alan-cooney Oct 26, 2023
6c4a9c0
Add tests to the main demo and push to the website (#441)
bryce13950 Oct 28, 2023
0c3ae68
Merge branch 'main' into main
alan-cooney Nov 10, 2023
4 changes: 4 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -334,6 +334,10 @@ def input_to_embed(
# keys and queries. See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
#TODO: alibi embedding doesnt do anything

Collaborator:

Why is this a TODO? Should Alibi do something here?

Contributor Author:

It's no longer needed, deleted.

elif self.cfg.positional_embedding_type == 'alibi':
residual = embed
shortformer_pos_embed = None
else:
raise ValueError(
f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
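
For readers following the positional-embedding branching above, a minimal sketch of the idea (a hypothetical helper, simplified from the real method, which also handles caching and shortformer embeddings): with ALiBi nothing positional is added to the residual stream here; position information only enters later as a bias on the attention scores.

# Simplified sketch (hypothetical helper, not the actual input_to_embed):
def start_residual(embed, pos_embed, positional_embedding_type: str):
    if positional_embedding_type == "standard":
        return embed + pos_embed   # absolute positions added to the residual stream up front
    elif positional_embedding_type in ("rotary", "alibi"):
        return embed               # positions are applied inside the attention layer instead
    raise ValueError(f"Invalid positional_embedding_type passed in {positional_embedding_type}")
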
5 changes: 5 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
@@ -195,6 +195,11 @@ class HookedTransformerConfig:
dtype: torch.dtype = torch.float32
tokenizer_prepends_bos: Optional[bool] = None

# bloom flags
post_embedding_layer_norm: bool = False



def __post_init__(self):
if self.n_heads == -1:
self.n_heads = self.d_model // self.d_head
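
For illustration, a hedged sketch of a Bloom-style configuration using the new flag (the values are indicative of bloom-560m and are assumptions here, not taken from this diff; the authoritative mapping lives in convert_hf_model_config in loading_from_pretrained.py below):

from transformer_lens import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=24,
    d_model=1024,
    n_ctx=2048,
    d_head=64,
    n_heads=16,
    d_mlp=4096,
    d_vocab=250880,
    act_fn="gelu_fast",
    normalization_type="LN",
    positional_embedding_type="alibi",
    post_embedding_layer_norm=True,  # new Bloom flag: LayerNorm applied right after the embedding
)
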
68 changes: 66 additions & 2 deletions transformer_lens/components.py
@@ -8,6 +8,7 @@
from typing import Dict, Optional, Tuple, Union

import einops
import math
import numpy as np
import torch
import torch.nn as nn
@@ -32,12 +33,17 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype)
)
# bloom needs post embedding layer norm
if cfg.post_embedding_layer_norm:
self.ln = LayerNorm(cfg)
alan-cooney marked this conversation as resolved.

def forward(
self, tokens: Int[torch.Tensor, "batch pos"]
) -> Float[torch.Tensor, "batch pos d_model"]:
# If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d]
# B acts as a tensor of indices into the second dimension (so >=0 and <b)
if self.cfg.post_embedding_layer_norm:
return self.ln(self.W_E[tokens, :])
return self.W_E[tokens, :]
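
To make the indexing comment above concrete, a small standalone sketch (toy sizes, plain tensors rather than the module itself): indexing W_E with a [batch, pos] tensor of token ids gives [batch, pos, d_model], and the Bloom path then applies LayerNorm over d_model.

import torch

d_vocab, d_model = 100, 8
W_E = torch.randn(d_vocab, d_model)
tokens = torch.randint(0, d_vocab, (2, 5))   # [batch=2, pos=5]

embeds = W_E[tokens, :]                      # [2, 5, 8]: tokens index the vocab dimension
ln = torch.nn.LayerNorm(d_model)             # stand-in for the post-embedding LayerNorm
print(embeds.shape, ln(embeds).shape)        # torch.Size([2, 5, 8]) twice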


@@ -303,7 +309,7 @@ def forward(
]:
if self.cfg.dtype not in [torch.float32, torch.float64]:
x = x.to(torch.float32)

x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length]
scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
(x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
@@ -478,6 +484,8 @@ def __init__(
)
self.register_buffer("rotary_sin", sin)
self.register_buffer("rotary_cos", cos)



@property
def OV(self) -> FactoredMatrix:
@@ -533,7 +541,6 @@ def forward(
qkv_einops_string = "batch pos head_index d_model"
else:
qkv_einops_string = "batch pos d_model"

q = self.hook_q(
einsum(
f"{qkv_einops_string}, head_index d_model d_head \
@@ -593,6 +600,27 @@ def forward(
)
/ self.attn_scale
) # [batch, head_index, query_pos, key_pos]

alan-cooney marked this conversation as resolved.
# alibi encoding before applying causal mask
if self.cfg.positional_embedding_type == 'alibi':
#TODO: not sure about the side effect of not using standard, double check

Collaborator:

What does this mean?

Contributor Author:

A reminder for myself to double check any potential side effect of setting the embedding type to something other than standard; no longer needed, deleted!

batch_size = attn_scores.size(0)
seq_len = attn_scores.size(-2)

Collaborator:

Should this be -1? Note that when generating text the attention scores are not square.

Contributor Author:

Yes, it should be set to the key length; changed to -1. Thanks!
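
A quick illustration of the reviewer's point (hypothetical shapes, assuming a KV cache during generation): the query axis can be length 1 while the key axis spans the whole context, so size(-2) and size(-1) differ, and the key length is the one the bias must match.

import torch

batch, n_heads, d_head, past_len = 2, 16, 64, 10
q = torch.randn(batch, n_heads, 1, d_head)             # only the newest token's query
k = torch.randn(batch, n_heads, past_len + 1, d_head)  # cached keys plus the new key

attn_scores = q @ k.transpose(-1, -2)                  # [batch, n_heads, 1, past_len + 1]
print(attn_scores.size(-2), attn_scores.size(-1))      # 1 11 -- not square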

additive_mask = torch.ones(batch_size, seq_len)

Collaborator:

Small point, but I think it may be clearer to move this additive_mask into build_alibi_tensor; then it's easier to explain (instead we can just pass the relevant sizes to that function). What do you think?

Collaborator:

Also, it needs to have its device set if not (so that it's on the same device as QK).

dtype = self.cfg.dtype if self.cfg.dtype in [torch.float32, torch.float64] else torch.float32
alibi = self.build_alibi_tensor(
attention_mask=additive_mask,
num_heads=self.cfg.n_heads,
dtype=dtype
).to(attn_scores.device)

# Huggingface impl uses torch.Tensor.baddbmm, with alpha = 1/sqrt(d_head), and beta=1
# and alibi.baddbmm(q,k) = beta * alibi + alpha * (q@k),
# here the `attn_scores` is already scaled by a factor of self.attn_scale,
# we only need to add alibi matrix to the result
assert alibi.shape == (attn_scores.size(0), attn_scores.size(1), 1, attn_scores.size(-1)), f"alibi shape {alibi.shape}, expecting {attn_scores.shape}"
attn_scores += alibi # [batch, head_index, query_pos, key_pos]

if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
@@ -756,7 +784,43 @@ def apply_rotary(
x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

return torch.cat([x_rotated, x_pass], dim=-1)
def build_alibi_tensor(
Collaborator:

create_attention_linear_bias or create_alibi_bias? I'm terrible at naming things, so not the best person to suggest here, but it feels like we shouldn't have "tensor" in the name.

Contributor Author:

create_attention_linear_bias sounds great to me. (I was naming it build_alibi_tensor only because it was named so in the HF code.)

self,
attention_mask: torch.Tensor, # batch pos
num_heads: int,
dtype: torch.dtype
) -> Float[torch.Tensor, "batch head_index 1 pos"]:
"""
https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/src/transformers/models/bloom/modeling_bloom.py#L86
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)

if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
# originally it returns tensor of shape batch * head_index, 1, pos
return alibi.reshape(batch_size, num_heads, 1, seq_length).to(dtype)

# MLP Layers
class MLP(nn.Module):
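
To ground the discussion above, a self-contained sketch of the ALiBi computation as this diff implements it (illustrative sizes; padding is ignored here, whereas the diff derives positions from the attention mask via cumsum): per-head slopes come from the nearest power of two at or below n_heads, each key position gets a linearly growing bias, and the bias is simply added to the already-scaled attention scores, matching Hugging Face's baddbmm with beta=1 and alpha=1/sqrt(d_head).

import math
import torch

def alibi_slopes(n_heads: int) -> torch.Tensor:
    """Per-head ALiBi slopes, mirroring the power-of-two scheme in build_alibi_tensor."""
    closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest_power_of_2))
    if closest_power_of_2 != n_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        n_extra = min(closest_power_of_2, n_heads - closest_power_of_2)
        extra = torch.pow(torch.tensor(extra_base), torch.arange(1, 1 + 2 * n_extra, 2))
        slopes = torch.cat([slopes, extra])
    return slopes  # [n_heads]

batch, n_heads, d_head, seq_len = 2, 16, 64, 5
q = torch.randn(batch, n_heads, seq_len, d_head)
k = torch.randn(batch, n_heads, seq_len, d_head)

attn_scores = (q @ k.transpose(-1, -2)) / math.sqrt(d_head)   # [batch, head, q_pos, k_pos]
alibi = alibi_slopes(n_heads)[None, :, None, None] * torch.arange(seq_len)[None, None, None, :]
attn_scores = attn_scores + alibi    # bias broadcasts over query positions, as in the diff
print(alibi.shape)                   # torch.Size([1, 16, 1, 5])
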
84 changes: 84 additions & 0 deletions transformer_lens/loading_from_pretrained.py
@@ -136,6 +136,7 @@
"stabilityai/stablelm-base-alpha-7b",
"stabilityai/stablelm-tuned-alpha-3b",
"stabilityai/stablelm-tuned-alpha-7b",
"bigscience/bloom-560m",
]
"""Official model names for models on HuggingFace."""

@@ -494,6 +495,9 @@
"stablelm-tuned-alpha-7b",
"stablelm-tuned-7b",
],
"bigscience/bloom-560m": [
"bloom-560m"
],
}
"""Model aliases for models on HuggingFace."""

@@ -721,6 +725,23 @@ def convert_hf_model_config(model_name: str, **kwargs):
"act_fn": "gelu",
"attention_dir": "bidirectional",
}
elif architecture == 'BloomForCausalLM':
cfg_dict = {
"d_model" : hf_config.hidden_size,
"d_head" : hf_config.hidden_size // hf_config.n_head,
"n_heads" : hf_config.n_head,
"d_mlp": hf_config.hidden_size * 4,
"n_layers": hf_config.n_layer,
"n_ctx": 2048, # is there a variable for this?
"d_vocab": hf_config.vocab_size,
"act_fn" : "gelu_fast",
"eps": hf_config.layer_norm_epsilon,
"normalization_type": "LN", # double check this
"post_embedding_layer_norm": True,
"positional_embedding_type": 'alibi'
}

# print("bloom config", cfg_dict)
else:
raise NotImplementedError(f"{architecture} is not currently supported.")
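
With the model name registered above and the weight conversion below, end-to-end usage would presumably look like this sketch (hedged: written against the public TransformerLens API, assuming this PR's loading path works as intended; the "bloom-560m" alias is the one added above):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("bloom-560m")   # alias added in this PR
logits, cache = model.run_with_cache("Hello, my name is")

print(logits.shape)                # [batch, pos, d_vocab]
print(cache["pattern", 0].shape)   # [batch, n_heads, query_pos, key_pos]
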
# All of these models use LayerNorm
Expand Down Expand Up @@ -1043,6 +1064,8 @@ def get_pretrained_state_dict(
state_dict = convert_llama_weights(hf_model, cfg)
elif cfg.original_architecture == "BertForMaskedLM":
state_dict = convert_bert_weights(hf_model, cfg)
elif cfg.original_architecture == "BloomForCausalLM":
state_dict = convert_bloom_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
@@ -1623,6 +1646,65 @@ def convert_bert_weights(bert, cfg: HookedTransformerConfig):
return state_dict


#TODO: bloom weight conversion
def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
state_dict = {}

state_dict["embed.W_E"] = bloom.transformer.word_embeddings.weight

# Bloom uses post embedding layer norm
state_dict["embed.ln.w"] = bloom.transformer.word_embeddings_layernorm.weight
print(state_dict['embed.ln.w'][:5])
state_dict["embed.ln.b"] = bloom.transformer.word_embeddings_layernorm.bias



for l in range(cfg.n_layers):
state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight
state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias

# Bloom's attention weights are stored as a fused matrix. BloomAttn: Linear(in=1024, out=3072)
# The returned .weight matrix has shape (3072, 1024)
W = bloom.transformer.h[l].self_attention.query_key_value.weight
# First transpose -> (1024, 3072), then split into (d_model, n_heads, 3, d_head)
W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)

W_Q, W_K, W_V = W_split[...,0,:], W_split[...,1,:], W_split[...,2,:]
W_Q = einops.rearrange(W_Q, "m n h ->n m h", n=cfg.n_heads)
W_K = einops.rearrange(W_K, "m n h ->n m h", n=cfg.n_heads)
W_V = einops.rearrange(W_V, "m n h ->n m h", n=cfg.n_heads)
state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
state_dict[f"blocks.{l}.attn.W_K"] = W_K
state_dict[f"blocks.{l}.attn.W_V"] = W_V

qkv_bias = bloom.transformer.h[l].self_attention.query_key_value.bias
qkv_bias = qkv_bias.reshape(cfg.n_heads, 3, cfg.d_head)

state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[:, 0, :]
state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[:, 1, :]
state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :]

W_O = bloom.transformer.h[l].self_attention.dense.weight.T #[1024, 1024]
W_O = einops.rearrange(W_O, "(n h) m->n h m", n=cfg.n_heads) # [n_heads, d_head, d_model]
state_dict[f"blocks.{l}.attn.W_O"] = W_O
state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias

state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight
state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[l].post_attention_layernorm.bias

W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T
state_dict[f"blocks.{l}.mlp.W_in"] = W_in
state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias

W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T
state_dict[f"blocks.{l}.mlp.W_out"] = W_out
state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias
state_dict["unembed.W_U"] = bloom.lm_head.weight.T # why transpose? cuz right mult?

state_dict["ln_final.w"] = bloom.transformer.ln_f.weight
state_dict["ln_final.b"] = bloom.transformer.ln_f.bias
return state_dict
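
A standalone sketch of the fused QKV split performed in the loop above (illustrative bloom-560m-like sizes and random weights rather than a real checkpoint): the fused Linear weight of shape [3*d_model, d_model] is transposed, reshaped to [d_model, n_heads, 3, d_head], sliced into Q/K/V, and rearranged into TransformerLens's [n_heads, d_model, d_head] layout.

import einops
import torch

d_model, n_heads = 1024, 16
d_head = d_model // n_heads

W = torch.randn(3 * d_model, d_model)               # stands in for query_key_value.weight
W_split = W.T.reshape(d_model, n_heads, 3, d_head)  # [d_model, n_heads, 3, d_head]

W_Q = einops.rearrange(W_split[..., 0, :], "m n h -> n m h")  # [n_heads, d_model, d_head]
W_K = einops.rearrange(W_split[..., 1, :], "m n h -> n m h")
W_V = einops.rearrange(W_split[..., 2, :], "m n h -> n m h")
print(W_Q.shape)  # torch.Size([16, 1024, 64])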

@dataclasses.dataclass
class Config:
d_model: int = 768
@@ -1660,3 +1742,5 @@ def get_basic_config(model_name: str, **kwargs) -> Config:
]
}
)

# %%