Added support for bloom-560m model #434
Changes from 2 commits
@@ -8,6 +8,7 @@
from typing import Dict, Optional, Tuple, Union

import einops
import math
import numpy as np
import torch
import torch.nn as nn
@@ -32,12 +33,17 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
            torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype)
        )
        # bloom needs post embedding layer norm
        if cfg.post_embedding_layer_norm:
            self.ln = LayerNorm(cfg)

    def forward(
        self, tokens: Int[torch.Tensor, "batch pos"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d]
        # B acts as a tensor of indices into the second dimension (so >=0 and <b)
        if self.cfg.post_embedding_layer_norm:
            return self.ln(self.W_E[tokens, :])
        return self.W_E[tokens, :]
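As a minimal sketch of what this hunk does (standalone, with hypothetical sizes and a plain `nn.LayerNorm` standing in for the hooked `LayerNorm` module used in the PR), indexing `W_E` with a `[batch, pos]` tensor of token ids gives a `[batch, pos, d_model]` tensor, which BLOOM then passes through a layer norm:

```python
import torch
import torch.nn as nn

d_vocab, d_model = 50, 8              # hypothetical sizes for illustration
W_E = nn.Parameter(torch.empty(d_vocab, d_model).normal_())
ln = nn.LayerNorm(d_model)            # stand-in for the hooked LayerNorm

tokens = torch.randint(0, d_vocab, (2, 5))   # [batch=2, pos=5]
embed = W_E[tokens, :]                       # fancy indexing -> [2, 5, d_model]
assert embed.shape == (2, 5, d_model)

# With post_embedding_layer_norm set (as BLOOM requires), the embeddings are normalised:
post_ln = ln(embed)
assert post_ln.shape == embed.shape
```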
@@ -303,7 +309,7 @@ def forward(
    ]:
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        x = x - x.mean(axis=-1, keepdim=True)  # [batch, pos, length]
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
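For reference, the centring and scale computed here are the standard LayerNorm normalisation (without the learned affine). A small sketch with a toy tensor, omitting the hook wrapper and config handling of the real module:

```python
import torch

x = torch.randn(2, 3, 16)            # toy [batch, pos, d_model]
eps = 1e-5

x_centred = x - x.mean(dim=-1, keepdim=True)
scale = (x_centred.pow(2).mean(-1, keepdim=True) + eps).sqrt()
normed = x_centred / scale

# Each d_model-sized vector now has (approximately) zero mean and unit variance.
assert torch.allclose(normed.mean(-1), torch.zeros(2, 3), atol=1e-5)
assert torch.allclose(normed.var(-1, unbiased=False), torch.ones(2, 3), atol=1e-3)
```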
@@ -478,6 +484,8 @@ def __init__(
        )
        self.register_buffer("rotary_sin", sin)
        self.register_buffer("rotary_cos", cos)

    @property
    def OV(self) -> FactoredMatrix:
@@ -533,7 +541,6 @@ def forward(
            qkv_einops_string = "batch pos head_index d_model"
        else:
            qkv_einops_string = "batch pos d_model"

        q = self.hook_q(
            einsum(
                f"{qkv_einops_string}, head_index d_model d_head \
@@ -593,6 +600,27 @@ def forward(
            )
            / self.attn_scale
        )  # [batch, head_index, query_pos, key_pos]

        # alibi encoding before applying causal mask
        if self.cfg.positional_embedding_type == 'alibi':
            # TODO: not sure about the side effect of not using standard, double check
            batch_size = attn_scores.size(0)
            seq_len = attn_scores.size(-2)

            additive_mask = torch.ones(batch_size, seq_len)

            dtype = self.cfg.dtype if self.cfg.dtype in [torch.float32, torch.float64] else torch.float32
            alibi = self.build_alibi_tensor(
                attention_mask=additive_mask,
                num_heads=self.cfg.n_heads,
                dtype=dtype
            ).to(attn_scores.device)

            # The Huggingface impl uses torch.Tensor.baddbmm with alpha = 1/sqrt(d_head) and beta = 1,
            # i.e. alibi.baddbmm(q, k) = beta * alibi + alpha * (q @ k).
            # Here attn_scores is already scaled by a factor of self.attn_scale,
            # so we only need to add the alibi matrix to the result.
            assert alibi.shape == (attn_scores.size(0), attn_scores.size(1), 1, attn_scores.size(-1)), \
                f"alibi shape {alibi.shape}, expecting {attn_scores.shape}"
            attn_scores += alibi  # [batch, head_index, query_pos, key_pos]

        if self.cfg.attention_dir == "causal":
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(

Review comments on this hunk:

On the `# TODO: not sure about the side effect of not using standard` comment —
Reviewer: What does this mean?
Author: A reminder for myself to double check any potential side effect of setting embedding type to something other than `standard`.

On `seq_len = attn_scores.size(-2)` —
Reviewer: Should this be -1? Note that when generating text the attention scores are not square.
Author: Yes, it should be set to key_length; changed to -1. Thanks!

On `additive_mask = torch.ones(batch_size, seq_len)` —
Reviewer: Small point, but I think it may be clearer to move this.
Reviewer: Also it needs to have its device set if not (so that it's on the same device as QK).
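The relationship stated in the code comment can be checked directly: `baddbmm` with `beta=1` and `alpha=1/sqrt(d_head)` gives the same result as scaling `q @ k^T` first and then adding the bias, which is what this hunk does. A small sketch with made-up shapes (nothing below beyond that relationship comes from the PR):

```python
import math
import torch

batch_heads, query_pos, key_pos, d_head = 6, 4, 4, 8
q = torch.randn(batch_heads, query_pos, d_head)
k = torch.randn(batch_heads, key_pos, d_head)
alibi = torch.randn(batch_heads, 1, key_pos)   # one bias row per (batch, head), broadcast over queries

# Hugging Face BLOOM: beta * alibi + alpha * (q @ k^T), with beta=1, alpha=1/sqrt(d_head)
hf_scores = alibi.baddbmm(q, k.transpose(1, 2), beta=1.0, alpha=1.0 / math.sqrt(d_head))

# This PR: attn_scores already divided by attn_scale, then the alibi bias is added
tl_scores = q @ k.transpose(1, 2) / math.sqrt(d_head) + alibi

assert torch.allclose(hf_scores, tl_scores, atol=1e-5)
```

Note that during incremental generation with a key/value cache, `query_pos` and `key_pos` differ (the scores are not square), which is why the reviewer points out that the sequence length should be taken from the key dimension, `size(-1)`, rather than the query dimension.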
@@ -756,7 +784,43 @@ def apply_rotary(
        x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin

        return torch.cat([x_rotated, x_pass], dim=-1)

    def build_alibi_tensor(
        self,
        attention_mask: torch.Tensor,  # [batch, pos]
        num_heads: int,
        dtype: torch.dtype,
    ) -> Float[torch.Tensor, "batch head_index 1 pos"]:
        """
        Adapted from
        https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/src/transformers/models/bloom/modeling_bloom.py#L86

        Args:
            attention_mask (`torch.Tensor`):
                Token-wise attention mask, of shape (batch_size, max_seq_len).
            num_heads (`int`):
                Number of attention heads.
            dtype (`torch.dtype`):
                Dtype of the output tensor.

        Returns:
            Tensor of shape (batch_size, num_heads, 1, max_seq_len).
        """
        batch_size, seq_length = attention_mask.shape
        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
        base = torch.tensor(
            2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.pow(base, powers)

        if closest_power_of_2 != num_heads:
            extra_base = torch.tensor(
                2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
            )
            num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
            extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
            slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

        arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
        alibi = slopes[..., None] * arange_tensor
        # The Huggingface implementation returns a tensor of shape (batch * head_index, 1, pos);
        # here it is reshaped to (batch, head_index, 1, pos) instead.
        return alibi.reshape(batch_size, num_heads, 1, seq_length).to(dtype)

# MLP Layers
class MLP(nn.Module):
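To make the slope construction in `build_alibi_tensor` concrete, here is a small self-contained check (not part of the PR) that it reproduces the ALiBi recipe of one geometric slope per head, including the interleaved extra slopes used when the head count is not a power of two:

```python
import math
import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Same slope construction as build_alibi_tensor above, without the mask/position part.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))))
    slopes = torch.pow(base, torch.arange(1, 1 + closest_power_of_2))
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))))
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining, 2)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)])
    return slopes

# For a power-of-two head count n, head i gets slope 2^(-8i/n):
assert torch.allclose(alibi_slopes(16), torch.tensor([2 ** (-0.5 * i) for i in range(1, 17)]))

# For e.g. 12 heads, the first 8 slopes are 2^-1 .. 2^-8 and the remaining 4 are
# taken from the 16-head sequence at odd positions: 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5.
expected_12 = [2 ** (-i) for i in range(1, 9)] + [2 ** (-0.5 * i) for i in (1, 3, 5, 7)]
assert torch.allclose(alibi_slopes(12), torch.tensor(expected_12))
```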
On the `# TODO` comment above —
Reviewer: Why is this a TODO? Should Alibi do something here?
Author: It's no longer needed, deleted.