Added support for bloom-560m model #434
Conversation
@@ -334,6 +334,10 @@ def input_to_embed(
# keys and queries. See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
#TODO: alibi embedding doesnt do anything
Why is this a TODO? Should Alibi do something here?
It's no longer needed, deleted.
transformer_lens/components.py
Outdated
# alibi encoding before applying causal mask
if self.cfg.positional_embedding_type == 'alibi':
    #TODO: not sure about the side effect of not using standard, double check
What does this mean?
A reminder for myself to double-check any potential side effects of setting the embedding type to something other than standard. No longer needed, deleted!
transformer_lens/components.py
Outdated
if self.cfg.positional_embedding_type == 'alibi':
    #TODO: not sure about the side effect of not using standard, double check
    batch_size = attn_scores.size(0)
    seq_len = attn_scores.size(-2)
Should this be -1? Note that when generating text the attention scores are not square
Yes it should be set to key_length, changed to -1. Thanks!
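For illustration, a minimal sketch of the shapes involved, assuming attention scores are laid out as [batch, head_index, query_pos, key_pos]; the numbers are made up:

```python
import torch

# During incremental generation with a key-value cache, only the newest query
# attends over all cached keys, so the score matrix is not square.
batch, n_heads, query_len, key_len = 2, 4, 1, 10
attn_scores = torch.zeros(batch, n_heads, query_len, key_len)

print(attn_scores.size(-2))  # 1  -> query length
print(attn_scores.size(-1))  # 10 -> key length, which the ALiBi bias must span
```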
Thanks for the PR! I left some minor comments, but it overall looks pretty good. Have you tested that this gives (approx) the same logits as Bloom on HuggingFace? And would you be able to add a test for it to test_hooked_transformer? https://github.com/neelnanda-io/TransformerLens/blob/main/tests/acceptance/test_hooked_transformer.py My only hesitation with the test is that 560M is large enough that it might slow things down - thoughts @alan-cooney ?
Seems likely fine still. In general the tests are starting to take a bit too long however so I'll split the acceptance and unit tests up into parallel workflows. @SeuperHakkerJa nice work on this! One other thing is the docs are failing to build (probably a formatting error in docstrings). To fix it run
Thank you for your feedback and the comments. I'll begin fixing the issue soon.
Btw I switched to draft whilst you're working on this - but feel free to switch back when ready. Also happy to review the changes if you want Neel.
Thanks so much for doing this!
It looks good but I've added some suggestions to improve the readability a bit, and also left a few questions (particularly regarding the broadcasting onto QK). Let me know what you think.
    shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "alibi":
    residual = embed
Suggest we add a comment here along the lines of "ALiBi does not add positional embeddings to word embeddings, and instead it biases QK attention scores."
Maybe even link the paper i.e. https://arxiv.org/pdf/2108.12409.pdf p1.
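For reference, the form given on p. 1 of that paper: for query position i, a head-specific slope m scales a linear penalty that is added straight to the pre-softmax query-key scores:

$$\operatorname{softmax}\bigl(\mathbf{q}_i \mathbf{K}^\top + m \cdot [-(i-1), \ldots, -2, -1, 0]\bigr)$$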
# bloom flags
post_embedding_ln: bool = False
Please can you move the explanation to the docstring above so it's consistent
transformer_lens/components.py
Outdated
assert alibi.shape == (
    attn_scores.size(0),
    attn_scores.size(1),
    1,
    attn_scores.size(-1),
), f"alibi shape {alibi.shape}, expecting {attn_scores.shape}"
This looks like you're testing something that should be correct as long as your code is written correctly? If so, best to keep this out of the runtime code.
Also, is this right? If the query dimension is size 1, surely this only works for predicting the next token, but for the logits for all previous tokens it would then give an incorrect answer? That would mean training wouldn't work with this code, right (or any analysis that looks at logits other than those from the last token)?
My apologies for the oversight; I think it might not work as desired during training then. I was only considering broadcasting, not past_kv_cache. I will fix this shortly.
transformer_lens/components.py
Outdated
# Huggingface impl uses torch.Tensor.baddbmm, with alpha = 1/sqrt(d_head), and beta=1
# and alibi.baddbmm(q,k) = beta * alibi + alpha * (q@k),
# here the `attn_scores` is already scaled by a factor of self.attn_scale,
# we only need to add alibi matrix to the result
Nice to include this, but I think it belongs in the build_alibi_tensor function instead, as part of a broader explanation of how it works?
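As an aside, the equivalence described in those code comments can be checked numerically; a standalone sketch (tensor sizes are made up for illustration):

```python
import torch

d_head, q_len, k_len = 8, 4, 4
q = torch.randn(1, q_len, d_head)
k = torch.randn(1, k_len, d_head)
alibi = torch.randn(1, q_len, k_len)

# HuggingFace-style: fused beta * alibi + alpha * (q @ k^T) via baddbmm.
hf_scores = alibi.baddbmm(q, k.transpose(1, 2), beta=1.0, alpha=1.0 / d_head**0.5)

# Scale the raw scores first, then add the bias (as in the snippet above).
tl_scores = (q @ k.transpose(1, 2)) / d_head**0.5 + alibi

assert torch.allclose(hf_scores, tl_scores, atol=1e-6)
```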
transformer_lens/components.py
Outdated
def build_alibi_tensor(
    self,
    attention_mask: torch.Tensor,  # batch pos
    num_heads: int,
    dtype: torch.dtype,
) -> Float[torch.Tensor, "batch head_index 1 pos"]:
Would be great to have a decent docstring here with details about how it works, a reference, and all args etc. See https://neelnanda-io.github.io/TransformerLens/content/contributing.html#documentation for our new contributors guide on how to do this.
Would also be good to have a unit test if possible? Feels like we can check some basic things like (see the sketch after this list):
- For each specific head, the diagonal QK values are the same (e.g. the middle diagonal is 0s)
- For each specific head, the slope (m) is constant
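A rough sketch of such a test, assuming a hypothetical create_alibi_bias helper that returns the full per-head bias of shape [n_heads, query_pos, key_pos]. The helper below just restates the paper's rule; it is not the PR's implementation:

```python
import torch


def create_alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    """Hypothetical helper: full ALiBi bias, shape [n_heads, query_pos, key_pos]."""
    # Paper's rule: slopes form a geometric sequence starting at 2^(-8/n).
    slopes = torch.pow(
        2.0 ** (-8.0 / n_heads), torch.arange(1, n_heads + 1, dtype=torch.float32)
    )
    positions = torch.arange(seq_len)
    relative = positions[None, :] - positions[:, None]  # key_pos - query_pos
    return slopes[:, None, None] * relative[None, :, :]


def test_alibi_bias_properties():
    n_heads, seq_len = 8, 16
    bias = create_alibi_bias(n_heads, seq_len)

    # The middle diagonal (query_pos == key_pos) should be all zeros for every head.
    assert torch.all(bias.diagonal(dim1=-2, dim2=-1) == 0)

    # The slope along the key dimension should be constant within each head.
    diffs = bias[:, 0, 1:] - bias[:, 0, :-1]
    assert torch.allclose(diffs, diffs[:, :1].expand_as(diffs))
```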
transformer_lens/components.py
Outdated
if self.cfg.positional_embedding_type == "alibi":
    batch_size = attn_scores.size(0)
    seq_len = attn_scores.size(-1)
    additive_mask = torch.ones(batch_size, seq_len)
Small point, but I think it may be clearer to move this additive_mask into build_alibi_tensor, and then it's easier to explain (instead we can just pass the relevant sizes to that function). What do you think?
Also, it needs to have its device set if it isn't already (so that it's on the same device as QK).
transformer_lens/components.py
Outdated
@@ -757,6 +789,49 @@ def apply_rotary(

    return torch.cat([x_rotated, x_pass], dim=-1)

def build_alibi_tensor(
create_attention_linear_bias or create_alibi_bias? I'm terrible at naming things, so not the best person to suggest here, but it feels like we shouldn't have tensor in the name?
create_attention_linear_bias sounds great to me. (I was naming it build_alibi_tensor only because it was named so in the HF code.)
transformer_lens/components.py
Outdated
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
    2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
    device=attention_mask.device,
    dtype=torch.float32,
)
powers = torch.arange(
    1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.pow(base, powers)

if closest_power_of_2 != num_heads:
    extra_base = torch.tensor(
        2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
        device=attention_mask.device,
        dtype=torch.float32,
    )
    num_remaining_heads = min(
        closest_power_of_2, num_heads - closest_power_of_2
    )
    extra_powers = torch.arange(
        1,
        1 + 2 * num_remaining_heads,
        2,
        device=attention_mask.device,
        dtype=torch.int32,
    )
    slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[
    :, None, :
]
alibi = slopes[..., None] * arange_tensor
Makes sense - worth adding a comment using p. 5 of https://arxiv.org/pdf/2108.12409.pdf to explain a bit more about how this gets to the head-specific slope (m). But can we use their general formula to simplify this a bit? ("In general, for n heads, our set of slopes is the geometric sequence that starts at 2^(-8/n) and uses that same value as its ratio.")
For this function, I was also using HF's implementation. If you prefer, I can switch it back to the general (original?) formula.
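For what it's worth, a sketch of what the general formula could look like (illustrative only; it matches the HF computation above when num_heads is a power of two, while the HF code's extra branch handles other head counts differently):

```python
import torch


def alibi_slopes(num_heads: int) -> torch.Tensor:
    """Per-head ALiBi slopes: a geometric sequence starting at 2^(-8/n) with that same ratio."""
    exponents = torch.arange(1, num_heads + 1, dtype=torch.float32)
    return torch.pow(2.0 ** (-8.0 / num_heads), exponents)


print(alibi_slopes(8))  # 0.5, 0.25, 0.125, ... down to 2**-8
```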
Summary:
Side Note:
Thus, the 'original' code leveraged broadcasting and softmax's translation invariance to derive the
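To spell out that property with a toy example (a standalone sketch; the sizes are arbitrary): adding a per-query constant to a row of pre-softmax scores leaves the softmax unchanged, so a bias of m * j broadcast over queries yields the same attention pattern as the paper's m * (j - i).

```python
import torch

m, seq_len = 0.25, 6
i = torch.arange(seq_len)[:, None]  # query positions
j = torch.arange(seq_len)[None, :]  # key positions
scores = torch.randn(seq_len, seq_len)
causal = torch.zeros(seq_len, seq_len).masked_fill(j > i, float("-inf"))

# The two biases differ per query row only by the constant m * i,
# which softmax ignores, so the attention patterns match.
paper_style = torch.softmax(scores + m * (j - i) + causal, dim=-1)
broadcast_style = torch.softmax(scores + m * j + causal, dim=-1)
assert torch.allclose(paper_style, broadcast_style, atol=1e-6)
```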
Looks good - one small question but otherwise good to go.
…Org#436) Note these are also added to the makefile as this is currently the approach people use to run the tests. In the future we should probably remove this as it's better to stick to one language in the repo (and a .py script file can also do all of this).
This should add a bit of a speed boost.
Removes warning and speeds up poetry install.
- Organise in the order people usually want (e.g. description first)
- Remove the top image
- Add some buttons
- Fix all linting issues (e.g. should use * for bullets)
This will reduce compatibility issues with Jax
* Added santacoder to aliases
* Removed reference to multiquery parameter
* Added santacoder to tests
* Asserted that trust_remote_code=true for santacoder
* Added demo notebook for santacoder
* Removed print statements and forcibly set trust_remote_code=True
* Changed spacing and indentation for black
* Removed model type hint in convert weights method
* Removed santacoder test due to memory issues
* Added back in print statement for loading pretrained model
Hi! Is there any reason not to add the other BLOOM models in this PR? Thanks!
Not aware of any reason. Happy to review a PR if you want to add the other ones.
I was only concerned that the other models might be too large, potentially causing the unit tests to take too much time:
Thanks for the quick response! I don't know exactly how your unit tests work... so I won't do a PR to avoid issues in that sense. But basically if I want to load a larger model I just need to list it in
Yes, that's right! @alan-cooney I will then initiate this pull request to integrate the remaining bloom models (up to 7b), potentially this weekend or at some point next week.
Regarding the tests - I suspect they are actually fine, but an easy way to check is to run them on the CPU and see that they take at most a few seconds.
Description
I integrated support for the bloom-560m model, which uses ALiBi (attention with linear biases) instead of standard positional embeddings. Consequently, the positional_embedding_type flag can be set to 'alibi'.
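For example, a minimal usage sketch (the exact model alias is an assumption; it may be registered under a different name):

```python
from transformer_lens import HookedTransformer

# Hypothetical alias; adjust to whatever name this PR registers.
model = HookedTransformer.from_pretrained("bloom-560m")
print(model.cfg.positional_embedding_type)  # expected: "alibi"

logits = model("Hello, my name is")
print(logits.shape)  # [batch, pos, d_vocab]
```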