Add a function to convert nanogpt weights (#475)
* Add a function to convert nanogpt weights

* Remove need for bias parameter
adamkarvonen authored Jan 16, 2024
1 parent 7241042 commit 5754a0b
Showing 2 changed files with 102 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -80,6 +80,7 @@
"convert_gptj_weights",
"convert_llama_weights",
"convert_mingpt_weights",
"convert_nanogpt_weights",
"convert_neel_solu_old_weights",
"convert_neo_weights",
"convert_neox_weights",
101 changes: 101 additions & 0 deletions transformer_lens/loading_from_pretrained.py
@@ -1721,6 +1721,107 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
    return state_dict


def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
"""For https://github.com/karpathy/nanoGPT
There are two complications with converting nanogpt models:
The first is that some state dicts have an unwanted prefix on keys that needs to be removed.
The second is that the models can be saved with or without bias. By default, there
is no bias. This function can handle both cases."""
# Nanogpt models saved after torch.compile() have this unwanted prefix
# This is a simple way to remove it
unwanted_prefix = "_orig_mod."
for k, v in list(old_state_dict.items()):
if k.startswith(unwanted_prefix):
old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k)

new_state_dict = {}
new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"]
new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"]

new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"]
new_state_dict["ln_final.b"] = torch.zeros_like(
old_state_dict["transformer.ln_f.weight"]
)
new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T

bias = False
if "transformer.ln_f.bias" in old_state_dict:
bias = True
new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"]

for layer in range(cfg.n_layers):
layer_key = f"transformer.h.{layer}"

new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[
f"{layer_key}.ln_1.weight"
]
# A bias of zeros is required for folding layer norm
new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like(
old_state_dict[f"{layer_key}.ln_1.weight"]
)
new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[
f"{layer_key}.ln_2.weight"
]
new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like(
old_state_dict[f"{layer_key}.ln_2.weight"]
)

W = old_state_dict[f"{layer_key}.attn.c_attn.weight"]
W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q
new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K
new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V

W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"]
W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O

new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[
f"{layer_key}.mlp.c_fc.weight"
].T
new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[
f"{layer_key}.mlp.c_proj.weight"
].T

if bias:
new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[
f"{layer_key}.ln_1.bias"
]
new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[
f"{layer_key}.ln_2.bias"
]
new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[
f"{layer_key}.mlp.c_fc.bias"
]
new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[
f"{layer_key}.mlp.c_proj.bias"
]

B = old_state_dict[f"{layer_key}.attn.c_attn.bias"]
B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0)
B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads)
B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads)
B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads)
new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q
new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K
new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V
new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[
f"{layer_key}.attn.c_proj.bias"
]

new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[
f"{layer_key}.mlp.c_fc.bias"
].T
new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[
f"{layer_key}.mlp.c_proj.bias"
].T

return new_state_dict


def convert_bert_weights(bert, cfg: HookedTransformerConfig):
    embeddings = bert.bert.embeddings
    state_dict = {
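
Not part of the diff: a minimal usage sketch for the new converter. It assumes a checkpoint written by nanoGPT's train.py, which stores the raw state dict under a "model" key and the architecture hyperparameters under "model_args"; the checkpoint path, the exact HookedTransformerConfig fields, and the strict=False load are illustrative assumptions, not something this commit prescribes.

import torch

from transformer_lens import HookedTransformer, HookedTransformerConfig
from transformer_lens.loading_from_pretrained import convert_nanogpt_weights

# Hypothetical checkpoint path; nanoGPT's train.py saves a dict holding the
# raw state dict ("model") and the model hyperparameters ("model_args").
checkpoint = torch.load("out/ckpt.pt", map_location="cpu")
args = checkpoint["model_args"]

cfg = HookedTransformerConfig(
    n_layers=args["n_layer"],
    n_heads=args["n_head"],
    d_model=args["n_embd"],
    d_head=args["n_embd"] // args["n_head"],
    n_ctx=args["block_size"],
    d_vocab=args["vocab_size"],
    act_fn="gelu",  # nanoGPT's MLP uses nn.GELU(); assumed closest match
    normalization_type="LN",
)

state_dict = convert_nanogpt_weights(checkpoint["model"], cfg)
model = HookedTransformer(cfg)
# strict=False because the converted dict omits buffers that HookedTransformer
# builds itself (e.g. attention masks) and omits biases for bias=False models.
model.load_state_dict(state_dict, strict=False)

The zero biases that the converter inserts for the LayerNorms exist so that downstream weight processing (such as LayerNorm folding) has explicit bias tensors to work with, as noted in the function's own comments.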
