Add support for Bloom 560m (#434)
SeuperHakkerJa authored Nov 10, 2023
1 parent b46ff94 commit f5a7d45
Showing 6 changed files with 292 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tests/acceptance/test_hooked_transformer.py
@@ -31,6 +31,7 @@
"gelu-2l",
"othello-gpt",
"tiny-stories-33M",
"bloom-560m",
"santacoder",
]
text = "Hello world!"
@@ -57,6 +58,7 @@
"redwood_attn_2l": 10.530948638916016,
"solu-1l": 5.256411552429199,
"tiny-stories-33M": 12.203617095947266,
"bloom-560m": 4.1953,
}

no_processing = [
40 changes: 40 additions & 0 deletions tests/unit/test_compute_linear_bias.py
@@ -0,0 +1,40 @@
import torch

from transformer_lens.components import Attention


def test_create_alibi_slope():
n_ctx = 100

# Expected result computed in a non-vectorized way
expected = torch.zeros((n_ctx, n_ctx))
for row in range(n_ctx):
for col in range(n_ctx):
expected[row, col] = float(min(col - row, 0))

# Check against the method's vectorized version
result = Attention.create_alibi_slope(n_ctx)
assert torch.allclose(expected, result)


def test_create_alibi_bias():
n_heads = 2
n_ctx = 4

result = Attention.create_alibi_bias(n_heads, n_ctx, torch.device("cpu"))

for matrix in result:
n_row, n_col = matrix.size()
slope = -matrix[1, 0]
# Check if upper triangle is all zeros
assert torch.equal(torch.triu(matrix), torch.zeros_like(matrix))

ref_lower_triangle = torch.zeros_like(matrix)
for i in range(1, n_row):
for j in range(i):
ref_lower_triangle[i, j] = -slope * (i - j)

# Check if the lower triangle is decreasing by a constant slope (towards the bottom left corner).
assert torch.equal(
torch.tril(matrix, diagonal=-1), torch.tril(ref_lower_triangle, diagonal=-1)
)
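For reference, a minimal way to run just these new unit tests from the repository root (a sketch, assuming a standard pytest installation):

import pytest

# Run only the new ALiBi unit tests; "-q" keeps the output short.
pytest.main(["-q", "tests/unit/test_compute_linear_bias.py"])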
4 changes: 4 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -334,6 +334,10 @@ def input_to_embed(
# keys and queries. See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "alibi":
# ALiBi does not add positional embeddings to word embeddings; instead, it biases the QK attention scores.
residual = embed
shortformer_pos_embed = None
else:
raise ValueError(
f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
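For context (not part of the diff), a minimal sketch of exercising the newly supported model end to end; it mirrors the acceptance test above, so the printed loss should be roughly the 4.1953 expected there:

from transformer_lens import HookedTransformer

# Sketch: load the newly supported checkpoint and compute the loss on the test prompt.
model = HookedTransformer.from_pretrained("bloom-560m")
loss = model("Hello world!", return_type="loss")
print(loss.item())  # expected to be close to 4.1953 per the acceptance test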
3 changes: 3 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
@@ -147,6 +147,8 @@ class HookedTransformerConfig:
tokenizer_prepends_bos (bool, *optional*): This flag is set by set_tokenizer. It is set to True only
when the tokenizer automatically prepends the BOS token if initialized with add_bos_token=True.
We need this information to dynamically control bos prepending.
post_embedding_ln (bool): Whether to apply layer normalization after embedding the tokens. Defaults
to False.
"""

n_layers: int
@@ -194,6 +196,7 @@ class HookedTransformerConfig:
default_prepend_bos: bool = True
dtype: torch.dtype = torch.float32
tokenizer_prepends_bos: Optional[bool] = None
post_embedding_ln: bool = False

def __post_init__(self):
if self.n_heads == -1:
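Not part of the diff: a hedged sketch of how the new post_embedding_ln flag is meant to pair with the "alibi" positional_embedding_type when constructing a config by hand (the dimensions below are illustrative toy values, not Bloom's):

from transformer_lens import HookedTransformerConfig

# Illustrative toy dimensions; a Bloom-style config combines ALiBi attention biases
# with a LayerNorm applied directly after the token embedding.
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=64,
    n_ctx=128,
    d_head=16,
    n_heads=4,
    d_vocab=1000,
    act_fn="gelu",
    positional_embedding_type="alibi",
    post_embedding_ln=True,
)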
156 changes: 155 additions & 1 deletion transformer_lens/components.py
@@ -32,12 +32,17 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=cfg.dtype)
)
# Some models (e.g. Bloom) need post embedding layer norm
if cfg.post_embedding_ln:
self.ln = LayerNorm(cfg)

def forward(
self, tokens: Int[torch.Tensor, "batch pos"]
) -> Float[torch.Tensor, "batch pos d_model"]:
# If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d]
# B acts as a tensor of indices into the second dimension (so >=0 and <b)
if self.cfg.post_embedding_ln:
return self.ln(self.W_E[tokens, :])
return self.W_E[tokens, :]


@@ -478,6 +483,10 @@ def __init__(
)
self.register_buffer("rotary_sin", sin)
self.register_buffer("rotary_cos", cos)
elif self.cfg.positional_embedding_type == "alibi":
# The ALiBi bias will be constructed on the first forward pass.
# Note: while computationally efficient, initializing a bias at the maximum n_ctx (e.g. a (16, 1024, 1024) float32 tensor) will occupy ~256MiB of contiguous GPU memory, which may not be optimal for memory usage.
self.alibi = None

@property
def OV(self) -> FactoredMatrix:
@@ -533,7 +542,6 @@ def forward(
qkv_einops_string = "batch pos head_index d_model"
else:
qkv_einops_string = "batch pos d_model"

q = self.hook_q(
einsum(
f"{qkv_einops_string}, head_index d_model d_head \
@@ -593,6 +601,22 @@ def forward(
)
/ self.attn_scale
) # [batch, head_index, query_pos, key_pos]

if self.cfg.positional_embedding_type == "alibi":
query_ctx = attn_scores.size(-2)
# The key context length is the number of positions attended to, including all positions in the KV cache.
key_ctx = attn_scores.size(-1)

# Only recompute the bias when necessary, for efficiency.
if self.alibi is None or key_ctx > self.alibi.size(-1):
self.alibi = Attention.create_alibi_bias(
self.cfg.n_heads, key_ctx, self.cfg.device
)

attn_scores += self.alibi[
:, :query_ctx, :key_ctx
] # [batch, head_index, query_pos, key_pos]

if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
@@ -757,6 +781,136 @@ def apply_rotary(

return torch.cat([x_rotated, x_pass], dim=-1)

@staticmethod
def create_alibi_slope(
n_ctx: int, device: torch.device = None
) -> Float[torch.Tensor, "query key"]:
"""Create an ALiBi Slope Matrix.
Create the slope matrix used in ALiBi, before it is multiplied by the head-specific scalar.
See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
Examples:
>>> Attention.create_alibi_slope(3)
tensor([[ 0., 0., 0.],
[-1., 0., 0.],
[-2., -1., 0.]])
>>> Attention.create_alibi_slope(4)
tensor([[ 0., 0., 0., 0.],
[-1., 0., 0., 0.],
[-2., -1., 0., 0.],
[-3., -2., -1., 0.]])
Args:
n_ctx: The maximum number of tokens in a prompt.
Returns:
A tensor of shape (n_ctx, n_ctx), where the upper triangle is zero and the lower
triangle is decreasing by a constant slope of 1 (towards the bottom left corner).
"""
# Set rows as [[0, 1, 2, ...]]
rows = torch.arange(n_ctx, device=device).unsqueeze(0)

# Set cols as [[0],[1],[2]...]
cols = torch.arange(n_ctx, device=device).unsqueeze(1)

# Use broadcasting to create the desired lower triangular part of the matrix
slope_matrix = rows - cols

# Use clamp to set all positive values (the upper-right triangle) to zero
return slope_matrix.clamp(max=0).to(torch.float32)
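The broadcasting-plus-clamp construction above is equivalent to the explicit double loop used in the new unit test; a quick sketch of that check:

import torch
from transformer_lens.components import Attention

# Sketch: verify the vectorized slope matrix against an explicit loop.
n_ctx = 5
looped = torch.zeros((n_ctx, n_ctx))
for row in range(n_ctx):
    for col in range(n_ctx):
        looped[row, col] = float(min(col - row, 0))
assert torch.allclose(Attention.create_alibi_slope(n_ctx), looped)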

@staticmethod
def create_alibi_multipliers(
n_heads: int, device: torch.device = None
) -> Float[torch.Tensor, "head_idx"]:
"""Create the ALiBi Scalar Multipliers for each Head.
For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and
uses that same value as its ratio. For example, with 8 heads the values would be [1/(2^1),
1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... , 1/(2^8)].
See :meth:`create_alibi_bias` for the full ALiBi bias calculation.
Examples:
>>> Attention.create_alibi_multipliers(8)
tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
>>> Attention.create_alibi_multipliers(16)
tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])
Args:
n_heads: The number of heads in a layer.
device: The device to create the tensor on.
Returns:
A tensor of shape (n_heads,) containing the scalar multiplier for each head.
"""
# Calculate the starting value
start = 2 ** (-8 / n_heads)

# Generate the indices [0, 1, ..., n_heads-1]
indices = torch.arange(n_heads, device=device)

# Compute the multipliers, with the starting value being the same as the ratio
multipliers = start * (start**indices)

return multipliers
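Equivalently, the i-th multiplier is 2^(-8 * (i + 1) / n_heads); a tiny sketch of that identity:

import torch
from transformer_lens.components import Attention

# Sketch: the geometric-sequence construction matches the closed form 2^(-8 * (i + 1) / n).
n_heads = 8
i = torch.arange(n_heads)
closed_form = 2.0 ** (-8.0 * (i + 1) / n_heads)
assert torch.allclose(Attention.create_alibi_multipliers(n_heads), closed_form)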

@staticmethod
def create_alibi_bias(
n_heads: int, n_ctx: int, device: torch.device = None
) -> Float[torch.Tensor, "head_idx query key"]:
"""Create the ALiBi Bias for all Heads.
Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.
The broad idea behind ALiBi is to remove the positional encoding from the original transformer
model, and instead apply a bias to each attention score. This bias is proportional to the
distance between the query and key (i.e. it encourages paying less attention to more distant
tokens), and is added to the attention scores before the softmax. It is used in models such as
Bloom.
Examples:
>>> Attention.create_alibi_bias(2, 4, torch.device('cpu'))
tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000],
[-0.0625, 0.0000, 0.0000, 0.0000],
[-0.1250, -0.0625, 0.0000, 0.0000],
[-0.1875, -0.1250, -0.0625, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000],
[-0.0039, 0.0000, 0.0000, 0.0000],
[-0.0078, -0.0039, 0.0000, 0.0000],
[-0.0117, -0.0078, -0.0039, 0.0000]]])
Args:
n_heads: The number of heads in a layer.
n_ctx: The maximum number of tokens in a prompt.
device: The device to create the tensor on.
Returns:
The ALiBi bias that should be added to the attention scores before the softmax.
"""
# Create the slope matrix
slope: Float[torch.Tensor, "query key"] = Attention.create_alibi_slope(
n_ctx, device
)

# Create the scalar multiplier for each head.
multipliers: Float[
torch.Tensor, "head_idx"
] = Attention.create_alibi_multipliers(n_heads, device)

# The ALiBi bias is then m * slope_matrix
alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)

return alibi_bias
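Putting the pieces together, a short sketch (shapes are illustrative) of how this per-head bias broadcasts onto a batch of attention scores, mirroring the slicing done in the forward pass above:

import torch
from transformer_lens.components import Attention

# Sketch: combine the per-head bias with a dummy batch of attention scores.
n_heads, n_ctx, batch = 2, 4, 3
alibi = Attention.create_alibi_bias(n_heads, n_ctx, torch.device("cpu"))  # [head, query, key]
attn_scores = torch.zeros(batch, n_heads, n_ctx, n_ctx)  # [batch, head, query, key]
biased = attn_scores + alibi[:, :n_ctx, :n_ctx]  # broadcasts over the batch dimension
print(biased.shape)  # torch.Size([3, 2, 4, 4])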


# MLP Layers
class MLP(nn.Module):
