
Added support for bloom-560m model #434

Merged: 43 commits into TransformerLensOrg:main on Nov 10, 2023

Conversation

@SeuperHakkerJa (Contributor) commented Oct 21, 2023

Description

I integrated support for the bloom-560m model, which uses ALiBi (attention with linear biases) instead of standard positional embeddings. Consequently, the positional_embedding_type flag can now be set to 'alibi'.
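As a rough usage sketch (the "bloom-560m" alias and the config value shown are assumptions based on this description, not a confirmed API):

# Hypothetical usage once this PR is merged; "bloom-560m" is assumed to be the registered alias.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("bloom-560m")
# ALiBi biases the attention scores directly rather than adding positional embeddings,
# so the config should report the new positional embedding type.
print(model.cfg.positional_embedding_type)  # expected: "alibi"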

Type of change


  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@@ -334,6 +334,10 @@ def input_to_embed(
# keys and queries. See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
#TODO: alibi embedding doesnt do anything
Collaborator:
Why is this a TODO? Should Alibi do something here?

Contributor Author:
It's no longer needed, deleted.


# alibi encoding before applying causal mask
if self.cfg.positional_embedding_type == 'alibi':
#TODO: not sure about the side effect of not using standard, double check
Collaborator:
What does this mean?

Contributor Author:
It was a reminder to myself to double-check any potential side effects of setting the embedding type to something other than standard. It's no longer needed, so I deleted it!

if self.cfg.positional_embedding_type == 'alibi':
#TODO: not sure about the side effect of not using standard, double check
batch_size = attn_scores.size(0)
seq_len = attn_scores.size(-2)
Collaborator:
Should this be -1? Note that when generating text the attention scores are not square

Contributor Author:
Yes, it should be the key length; changed to -1. Thanks!
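As a small illustration of the point above (shapes are made up): during cached generation only the newest token forms the query, so the attention scores are not square and dim -2 differs from dim -1.

import torch

# [batch, head_index, query_pos, key_pos] while generating with a KV cache
attn_scores = torch.zeros(2, 16, 1, 10)
print(attn_scores.size(-2))  # 1  -> query length (the wrong dimension for the ALiBi bias)
print(attn_scores.size(-1))  # 10 -> key length (what the bias should span)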

@neelnanda-io (Collaborator):

Thanks for the PR! I left some minor comments, but it overall looks pretty good. Have you tested that this gives (approx) the same logits as Bloom on HuggingFace?

And would you be able to add a test for it to test_hooked_transformer? https://github.com/neelnanda-io/TransformerLens/blob/main/tests/acceptance/test_hooked_transformer.py

My only hesitation with the test is that 560M is large enough that it might slow things down - thoughts @alan-cooney ?
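A rough sketch of the kind of logit-equivalence check being asked for above, assuming the "bloom-560m" alias from this PR; it compares log-probs because TransformerLens's default weight processing can shift raw logits by a per-position constant, and the tolerance is only illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
tl_model = HookedTransformer.from_pretrained("bloom-560m")

tokens = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = tl_model(tokens)

# Compare log-probs at every position, not just the last token.
assert torch.allclose(
    hf_logits.log_softmax(dim=-1), tl_logits.log_softmax(dim=-1), atol=1e-3
)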

@alan-cooney (Collaborator):

My only hesitation with the test is that 560M is large enough that it might slow things down - thoughts @alan-cooney ?

Seems likely fine still. In general, though, the tests are starting to take a bit too long, so I'll split the acceptance and unit tests into parallel workflows.

@SeuperHakkerJa nice work on this! One other thing is the docs are failing to build (probably a formatting error in docstrings). To fix it run poetry run docs-hot-reload and then the warning with details about what went wrong will show up (hot reload should let you fix it and then instantly see the warning go away).

@SeuperHakkerJa (Contributor Author):

Thank you for your feedback and the comments. I'll begin fixing the issue soon.

@alan-cooney marked this pull request as draft October 22, 2023 23:28
@alan-cooney (Collaborator):

Btw I switched to draft whilst you're working on this - but feel free to switch back when ready.

Also happy to review the changes if you want, Neel.

@SeuperHakkerJa marked this pull request as ready for review October 24, 2023 02:08
@alan-cooney self-requested a review October 24, 2023 12:33
@alan-cooney (Collaborator) left a comment:

Thanks so much for doing this!

It looks good, but I've added some suggestions to improve the readability a bit, and also left a few questions (particularly regarding the broadcasting onto QK). Let me know what you think.

Comment on lines 336 to 338
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "alibi":
residual = embed
@alan-cooney (Collaborator), Oct 24, 2023:
Suggest we add a comment here along the lines of "ALiBi does not add positional embeddings to word embeddings, and instead it biases QK attention scores."

Maybe even link the paper, i.e. https://arxiv.org/pdf/2108.12409.pdf, p. 1.
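For reference, the scoring rule from p. 3 of that paper, which such a comment could quote (m is the fixed, head-specific slope):

softmax(q_i @ K^T + m * [-(i - 1), ..., -2, -1, 0])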

Comment on lines 198 to 199
# bloom flags
post_embedding_ln: bool = False
Collaborator:
Please can you move the explanation to the docstring above so it's consistent?

Comment on lines 620 to 625
assert alibi.shape == (
attn_scores.size(0),
attn_scores.size(1),
1,
attn_scores.size(-1),
), f"alibi shape {alibi.shape}, expecting {attn_scores.shape}"
Collaborator:
This looks like you're testing something that should be correct as long as your code is written correctly? If so, best to keep this out of the runtime code.

Collaborator:
Also, is this right?

If the query shape is size 1, surely this only works for predicting the next token, but for the logits of all previous tokens it would then give an incorrect answer? That would mean training wouldn't work with this code, right (or any analysis that looks at logits other than from the last token)?

@SeuperHakkerJa (Contributor Author), Oct 24, 2023:
My apologies for the oversight; I think it might not work as desired during training, then. I was only considering broadcasting, not past_kv_cache. I will fix this shortly.

Comment on lines 616 to 619
# Huggingface impl uses torch.Tensor.baddbmm, with alpha = 1/sqrt(d_head), and beta=1
# and alibi.baddbmm(q,k) = beta * alibi + alpha * (q@k),
# here the `attn_scores` is already scaled by a factor of self.attn_scale,
# we only need to add alibi matrix to the result
Collaborator:
Nice to include this, but I think it belongs in the build_alibi_tensor function instead, as part of a broader explanation of how it works?
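For what it's worth, a tiny sketch of the equivalence that comment describes (shapes and names are illustrative; as noted above, attn_scores is already scaled, so only the additive bias is needed):

import torch

d_head = 64
q = torch.randn(8, 10, d_head)  # (batch * heads, query_pos, d_head)
k = torch.randn(8, 10, d_head)  # (batch * heads, key_pos, d_head)
alibi = torch.randn(8, 1, 10)   # broadcasts over query positions

# Hugging Face: beta * alibi + alpha * (q @ k^T), with beta = 1 and alpha = 1/sqrt(d_head)
hf_scores = torch.baddbmm(alibi, q, k.transpose(1, 2), beta=1.0, alpha=d_head ** -0.5)

# Equivalent to adding the ALiBi matrix to already-scaled attention scores
scores = (q @ k.transpose(1, 2)) * d_head ** -0.5 + alibi
assert torch.allclose(hf_scores, scores, atol=1e-5)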

Comment on lines 792 to 797
def build_alibi_tensor(
self,
attention_mask: torch.Tensor, # batch pos
num_heads: int,
dtype: torch.dtype,
) -> Float[torch.Tensor, "batch head_index 1 pos"]:
Collaborator:
Would be great to have a decent docstring here with details about how it works, a reference, all args, etc. See https://neelnanda-io.github.io/TransformerLens/content/contributing.html for our new contributors' guide on how to do this.

@alan-cooney (Collaborator), Oct 24, 2023:
Would also be good to have a unit test if possible? Feels like we can check some basic things like:

  • For each specific head, the diagonal QK values are the same (e.g. the middle diagonal is 0s)
  • For each specific head, the slope (m) is constant
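A rough sketch of what such a test could check (the helper name and the [head_index, query_pos, key_pos] shape are placeholders, not the merged API):

import torch

def check_alibi_bias(bias: torch.Tensor) -> None:
    # bias: [head_index, query_pos, key_pos] ALiBi bias (hypothetical shape for this sketch).
    n_heads = bias.shape[0]
    for head in range(n_heads):
        head_bias = bias[head]
        # The main diagonal (distance 0) should be constant, typically all zeros.
        diagonal = torch.diagonal(head_bias)
        assert torch.allclose(diagonal, torch.full_like(diagonal, diagonal[0].item()))
        # Along the last row, consecutive entries should differ by the constant per-head slope m.
        last_row = head_bias[-1]
        differences = last_row[1:] - last_row[:-1]
        assert torch.allclose(differences, torch.full_like(differences, differences[0].item()))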

if self.cfg.positional_embedding_type == "alibi":
batch_size = attn_scores.size(0)
seq_len = attn_scores.size(-1)
additive_mask = torch.ones(batch_size, seq_len)
Collaborator:
Small point, but I think it may be clearer to move this additive_mask into build_alibi_tensor, and then it's easier to explain (instead we can just pass the relevant sizes to that function). What do you think?

Collaborator:
Also, it needs to have its device set if it isn't already (so that it's on the same device as QK).
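A small sketch of that device fix, reusing the names from the snippet above (illustrative only, not the merged code):

# assuming attn_scores is the [batch, head_index, query_pos, key_pos] score tensor from above
batch_size = attn_scores.size(0)
seq_len = attn_scores.size(-1)
additive_mask = torch.ones(batch_size, seq_len, device=attn_scores.device)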

@@ -757,6 +789,49 @@ def apply_rotary(

return torch.cat([x_rotated, x_pass], dim=-1)

def build_alibi_tensor(
Collaborator:
create_attention_linear_bias or create_alibi_bias? I'm terrible at naming things, so not the best person to suggest here, but it feels like we shouldn't have tensor in the name?

Contributor Author:
create_attention_linear_bias sounds great to me. (I was naming it build_alibi_tensor only because that's what it's called in the HF code.)

transformer_lens/components.py (resolved)
Comment on lines 798 to 831
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
powers = torch.arange(
1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.pow(base, powers)

if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
num_remaining_heads = min(
closest_power_of_2, num_heads - closest_power_of_2
)
extra_powers = torch.arange(
1,
1 + 2 * num_remaining_heads,
2,
device=attention_mask.device,
dtype=torch.int32,
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[
:, None, :
]
alibi = slopes[..., None] * arange_tensor
Collaborator:
Makes sense - worth adding a comment using p. 5 of https://arxiv.org/pdf/2108.12409.pdf to explain a bit more about how this gets to the head-specific slope (m).

But can we use their general formula to simplify this a bit? ("In general, for n heads, our set of slopes is the geometric sequence that starts at 2^(-8/n) and uses that same value as its ratio.")
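A sketch of what the general-formula version could look like; it reproduces the paper's slopes exactly when num_heads is a power of two (the branchy HF code above exists to handle other head counts):

import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric sequence starting at 2^(-8/n) with ratio 2^(-8/n): 2^(-8/n), 2^(-16/n), ..., 2^(-8)
    ratio = 2 ** (-8 / num_heads)
    return torch.pow(ratio, torch.arange(1, num_heads + 1, dtype=torch.float32))

print(alibi_slopes(8))  # tensor([0.5000, 0.2500, 0.1250, ..., 0.0039])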

Contributor Author:
For this function, I was also using HF's implementation. If you prefer, I can switch it back to the general (original?) formula.


@SeuperHakkerJa (Contributor Author):

Summary:

  1. Comment Cleanup & Documentation:

    • Enhanced code readability by cleaning up comments.
    • Added detailed documentation for several functions, notably compute_attention_linear_bias and expand_alibi_on_query_dim.
  2. Reimplementation of Alibi Tensor Function:

    • This new implementation ensures that the diagonal values remain consistent.
    • To achieve this, the original Alibi tensor has been expanded on the query dimension (previously set to 1) and values have been filled as suggested in the original paper.

Side Note:
I examined the original code referenced in the Alibi paper (source). It contains the following remark:

#In the next line, the part after the * is what constructs the diagonal matrix (right matrix in Figure 3 in the paper). 
#If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3, but one where all rows are identical.
#This works because the softmax operation is invariant to translation, and our bias functions are always linear. 

Thus, the 'original' code leveraged broadcasting and softmax's translation invariance to derive the pattern. My implementation, as explained above, will additionally ensure that hook_attn_scores correctly captures the attention scores.
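A rough illustration of the expansion described above (the name expand_alibi_on_query_dim is taken from the summary; the shapes and signature here are guesses, not the merged code):

import torch

def expand_alibi_on_query_dim(slopes: torch.Tensor, query_len: int, key_len: int) -> torch.Tensor:
    # slopes: [n_heads]. Returns [n_heads, query_pos, key_pos] where entry (i, j) is
    # -m * (i - j) for visible keys (j <= i), matching Figure 3 of the ALiBi paper,
    # and 0 elsewhere (those entries are removed by the causal mask anyway).
    query_positions = torch.arange(key_len - query_len, key_len).unsqueeze(-1)  # absolute query positions
    key_positions = torch.arange(key_len).unsqueeze(0)
    relative = (key_positions - query_positions).clamp(max=0)  # j - i, clipped at 0
    return slopes[:, None, None] * relative

bias = expand_alibi_on_query_dim(torch.tensor([0.5, 0.25]), query_len=4, key_len=4)
print(bias[0, -1])  # tensor([-1.5000, -1.0000, -0.5000, 0.0000])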

@alan-cooney (Collaborator) left a comment:
Looks good - one small question but otherwise good to go.

transformer_lens/loading_from_pretrained.py (outdated, resolved)
alan-cooney and others added 15 commits November 10, 2023 08:45
…Org#436)

Note these are also added to the makefile as this is currently the approach people use to run the tests. In the future we should probably remove this as it's better to stick to one language in the repo (and a .py script file can also do all of this).
Removes warning and speeds up poetry install.
- Organise in the order people usually want (e.g. description first)
- Remove the top image
- Add some buttons
- Fix all linting issues (e.g. should use * for bullets)
This will reduce compatibility issues with Jax
* Added santacoder to aliases

* Removed reference to multiquery parameter

* Added santacoder to tests

* Asserted that trust_remote_code=true for santacoder

* Added demo notebook for santacoder

* Removed print statements and forcibly set trust_remote_code=True

* Changed spacing and identation for black

* Removed model type hint in convert weights method

* Removed santacoder test due to memory issues

* Added back in print statement for loading pretrained model
- Organise in the order people usually want (e.g. description first)
- Remove the top image
- Add some buttons
- Fix all linting issues (e.g. should use * for bullets)
This will reduce compatibility issues with Jax
@alan-cooney merged commit f5a7d45 into TransformerLensOrg:main Nov 10, 2023
6 of 8 checks passed
@gegallego:

Hi! Is there any reason not to add the other BLOOM models in this PR? Thanks!

@alan-cooney (Collaborator):

Hi! Is there any reason not to add the other BLOOM models in this PR? Thanks!

Not aware of any reason. Happy to review a PR if you want to add the other ones.

@SeuperHakkerJa (Contributor Author):

Hi! Is there any reason not to add the other BLOOM models in this PR? Thanks!

Not aware of any reason. Happy to review a PR if you want to add the other ones.

I was only concerned that the other models might be too large, potentially causing the unit tests to take too much time.

@gegallego:

Thanks for the quick response! I don't know exactly how your unit tests work... so I won't do a PR to avoid issues in that sense. But basically if I want to load a larger model I just need to list it in OFFICIAL_MODEL_NAMES and MODEL_ALIASES, right?

@SeuperHakkerJa (Contributor Author):

Thanks for the quick response! I don't know exactly how your unit tests work... so I won't do a PR to avoid issues in that sense. But basically if I want to load a larger model I just need to list it in OFFICIAL_MODEL_NAMES and MODEL_ALIASES, right?

Yes, that's right!

@alan-cooney I will open a pull request to integrate the remaining BLOOM models (up to 7B), potentially this weekend or at some point next week.
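For reference, a sketch of the kind of change being discussed, using real Hugging Face model IDs but otherwise hypothetical entries (the list and dict names come from the conversation above; the aliases are illustrative):

# transformer_lens/loading_from_pretrained.py (illustrative excerpt only)
OFFICIAL_MODEL_NAMES = [
    # ... existing entries ...
    "bigscience/bloom-560m",
    "bigscience/bloom-1b1",  # hypothetical addition
    "bigscience/bloom-3b",   # hypothetical addition
]

MODEL_ALIASES = {
    # ... existing entries ...
    "bigscience/bloom-1b1": ["bloom-1b1"],  # hypothetical addition
    "bigscience/bloom-3b": ["bloom-3b"],    # hypothetical addition
}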

@alan-cooney (Collaborator):

Hi! Is there any reason not to add the other BLOOM models in this PR? Thanks!

Not aware of any reason. Happy to review a PR if you want to add the other ones.

I was only concerned that the other models might be too large, potentially causing the unit tests to take too much time.

Regarding the tests - I suspect they are actually fine, but the easy way to check is to run them on the CPU and confirm they take at most a few seconds.
