From b892759ce21e150b40917b2d3c2d2eb7eaa78e22 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Dec 2024 08:48:59 -0500 Subject: [PATCH 01/18] add Llama-3.3-70B-Instruct (#1859) --- README.md | 1 + litgpt/config.py | 23 +++++++++++++++++++++++ tests/test_model.py | 1 + tutorials/download_model_weights.md | 2 ++ 4 files changed, 27 insertions(+) diff --git a/README.md b/README.md index 3856a332ea..f9d401248c 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,7 @@ Every model is written from scratch to maximize performance and remove layers of | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | Llama 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) | +| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-3-large-language-model-family/) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) | | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | diff --git a/litgpt/config.py b/litgpt/config.py index 684f3f78be..47bdc384d2 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -700,8 +700,31 @@ def norm_class(self) -> Type: rope_base=500000, rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192) ), + # https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/config.json + dict( + name="Llama-3.3-70B-Instruct", + hf_config=dict(org="meta-llama", name="Llama-3.3-70B-Instruct"), + block_size=131072, + vocab_size=128000, + padded_vocab_size=128256, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=28672, + rope_base=500000, + rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192) + ), ] for c in llama_3: + if c["name"] == "Llama-3.3-70B-Instruct": + configs.append(c) + continue for kind in ("", "-Instruct"): copy = deepcopy(c) copy["name"] = c["name"].format(kind) diff --git a/tests/test_model.py b/tests/test_model.py index 3ca5e80599..34e15ed0d3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -223,6 +223,7 @@ def test_against_original_open_llama_3b(device, dtype): {"name": "Llama-3.1-8B-Instruct"}, {"name": "Llama-3.2-1B"}, {"name": "Llama-3.2-3B"}, + {"name": "Llama-3.3-70B-Instruct"}, ], ) @pytest.mark.parametrize( diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 509218ac96..39631b6486 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -20,6 +20,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) | +| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) | | Llama 3.1 Nemotron | 70B | NVIDIA | [NVIDIA AI 2024](https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard) | | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | @@ -134,6 +135,7 @@ meta-llama/Llama-3.2-1B meta-llama/Llama-3.2-1B-Instruct meta-llama/Llama-3.2-3B meta-llama/Llama-3.2-3B-Instruct +meta-llama/Llama-3.3-70B-Instruct meta-llama/Meta-Llama-3-70B meta-llama/Meta-Llama-3-70B-Instruct meta-llama/Meta-Llama-3-8B From 7865e8a019bfeb2b0172ffe0963afa287c6a3043 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Dec 2024 09:42:43 -0500 Subject: [PATCH 02/18] add Salamandra (#1857) --- README.md | 1 + litgpt/config.py | 55 ++++++++++++++++++++++++++ litgpt/prompts.py | 8 ++++ litgpt/tokenizer.py | 3 ++ tests/test_model.py | 60 +++++++++++++++++++++++++++++ tutorials/download_model_weights.md | 5 +++ 6 files changed, 132 insertions(+) diff --git a/README.md b/README.md index f9d401248c..ad5dd03828 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ Every model is written from scratch to maximize performance and remove layers of | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | +| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | diff --git a/litgpt/config.py b/litgpt/config.py index 47bdc384d2..28d79474d9 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -2066,4 +2066,59 @@ def norm_class(self) -> Type: configs.extend(qwq) +############# +# Salamandra +############# + +salamandra = [ + # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json + dict( + name="salamandra-2b{}", + hf_config=dict(org="BSC-LT", name="salamandra-2b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=24, + n_head=16, + n_embd=2048, + n_query_groups=16, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=5440, + norm_eps=1e-5, + rope_base=10000 + ), + # https://huggingface.co/BSC-LT/salamandra-7b-instruct/blob/main/config.json + dict( + name="salamandra-7b{}", + hf_config=dict(org="BSC-LT", name="salamandra-7b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=32, + n_head=32, + n_embd=4096, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=11008, + norm_eps=1e-6, + rope_base=10000 + ), +] + +for c in salamandra: + for kind in ("", "-instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 5f5fd14494..96a99073b6 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -290,6 +290,11 @@ def apply(self, prompt: str, **kwargs: str) -> str: system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class Salamandra(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit. La meva base de coneixement es va actualitzar per última vegada l'agost de 2023." + return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + # Maps prompt style names to PromptStyle classes prompt_styles: Dict[str, Type[PromptStyle]] = { @@ -316,6 +321,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "olmo": OLMo, "qwen2.5": Qwen2_5, "qwq": QwQ, + "salamandra": Salamandra, } @@ -358,6 +364,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Qwen2_5() if re.search(r"QwQ-.*", model_name): return QwQ() + if re.search(r"salamandra-.*-instruct", model_name): + return Salamandra() return Default() diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index a81c59aa2d..6018a44734 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -143,6 +143,9 @@ def decode(self, tensor: torch.Tensor) -> str: if len(tokens) == 1 and self.apply_decoding_fix: dummy_token_id = 33 # \x1e dummy_token = self.processor.decode([dummy_token_id]) + if dummy_token != "\x1e": + dummy_token_id = 165 # \x1e is different in salamandra tokenizers + dummy_token = self.processor.decode([dummy_token_id]) return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :] return self.processor.decode(tokens) diff --git a/tests/test_model.py b/tests/test_model.py index 34e15ed0d3..b7da3750b7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -852,6 +852,66 @@ def test_against_original_qwen_2_5(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_salamandra(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + n_query_groups=2, + intermediate_size=86, + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @RunIf(dynamo=True) @torch.inference_mode() def test_model_compile(): diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 39631b6486..96316f0f0b 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -40,6 +40,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | @@ -63,6 +64,10 @@ The output is shown below: allenai/OLMo-1B-hf allenai/OLMo-7B-hf allenai/OLMo-7B-Instruct-hf +bsc-lt/salamandra-2b +bsc-lt/salamandra-2b-instruct +bsc-lt/salamandra-7b +bsc-lt/salamandra-7b-instruct codellama/CodeLlama-13b-hf codellama/CodeLlama-13b-Instruct-hf codellama/CodeLlama-13b-Python-hf From 9750eb684d24ebf1fda48968e250a5dffc82ee33 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Dec 2024 09:45:31 -0500 Subject: [PATCH 03/18] Qwen2.5: fix block size for Coder series (#1856) --- litgpt/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index 28d79474d9..d0f5e52a4e 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -1928,7 +1928,7 @@ def norm_class(self) -> Type: dict( name="Qwen2.5-Coder-1.5B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-1.5B{}"), - block_size=131072, + block_size=32768, vocab_size=151643, padded_vocab_size=151936, n_layer=28, @@ -1970,7 +1970,7 @@ def norm_class(self) -> Type: dict( name="Qwen2.5-Coder-7B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-7B{}"), - block_size=131072, + block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=28, @@ -1991,7 +1991,7 @@ def norm_class(self) -> Type: dict( name="Qwen2.5-Coder-14B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-14B{}"), - block_size=131072, + block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=48, @@ -2012,7 +2012,7 @@ def norm_class(self) -> Type: dict( name="Qwen2.5-Coder-32B{}", hf_config=dict(org="Qwen", name="Qwen2.5-Coder-32B{}"), - block_size=131072, + block_size=32768, vocab_size=151643, padded_vocab_size=152064, n_layer=64, From 34930ba9fe491516b16369be01da1a9ae9fac36f Mon Sep 17 00:00:00 2001 From: Yunfeng Wang Date: Sun, 8 Dec 2024 20:51:53 +0800 Subject: [PATCH 04/18] fix: add missing"," (#1855) Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> --- litgpt/data/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/data/__init__.py b/litgpt/data/__init__.py index 0f0f3d6cf9..b6d7275e5e 100644 --- a/litgpt/data/__init__.py +++ b/litgpt/data/__init__.py @@ -33,6 +33,6 @@ "TextFiles", "TinyLlama", "TinyStories", - "MicroLlama" + "MicroLlama", "get_sft_collate_fn", ] From a0f5bd32eee0c464fbe2abc3650fe07b77d76653 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 8 Dec 2024 09:26:48 -0500 Subject: [PATCH 05/18] fix llama3.3 readme url (#1862) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ad5dd03828..533872cc91 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ Every model is written from scratch to maximize performance and remove layers of | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | Llama 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/) | -| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-3-large-language-model-family/) | +| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) | | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) | From 4b3dd3bbc3199e1bd60847e14f445762db8de806 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Sun, 15 Dec 2024 15:14:12 +0300 Subject: [PATCH 06/18] Set torch.load(..., `weights_only=False`) in litgpt/api.py (#1874) --- litgpt/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litgpt/api.py b/litgpt/api.py index a114fdd512..ea156ce600 100644 --- a/litgpt/api.py +++ b/litgpt/api.py @@ -386,7 +386,7 @@ def distribute( model.eval() if generate_strategy == "sequential": - state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu") + state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False) model.load_state_dict(state_dict, assign=True) model = fabric.setup_module(model, move_to_device=False) @@ -405,7 +405,7 @@ def distribute( pbar = tqdm(total=fabric.world_size, desc="Loading model weights") for rank in range(fabric.world_size): if fabric.global_rank == rank: - state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu") + state_dict = torch.load(str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False) model.load_state_dict(state_dict, assign=True) # cannot use `.setup_module` because it will wrap with DDP From 972dee40725e0f67588e229a4d8ec882b2f47a9c Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Dec 2024 08:19:22 -0500 Subject: [PATCH 07/18] Add Qwen2.5 math (#1863) Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> --- README.md | 1 + litgpt/config.py | 68 ++++++++++++++++++++++++++++ litgpt/prompts.py | 7 +++ tests/test_convert_lit_checkpoint.py | 2 +- tests/test_model.py | 2 +- tutorials/download_model_weights.md | 7 +++ 6 files changed, 85 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 533872cc91..3dcf6b1f4e 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ Every model is written from scratch to maximize performance and remove layers of | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | +| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | diff --git a/litgpt/config.py b/litgpt/config.py index d0f5e52a4e..54420826bb 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -2033,6 +2033,74 @@ def norm_class(self) -> Type: qwen_2_5.extend(qwen_2_5_coder) +qwen_2_5_math = [ + # https://huggingface.co/Qwen/Qwen2.5-Math-1.5B/blob/main/config.json + dict( + name="Qwen2.5-Math-1.5B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-1.5B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=151936, + n_layer=28, + n_head=12, + n_embd=1536, + n_query_groups=2, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8960, + norm_eps=1e-6, + rope_base=10000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Math-7B/blob/main/config.json + dict( + name="Qwen2.5-Math-7B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-7B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=28, + n_head=28, + n_embd=3584, + n_query_groups=4, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=18944, + norm_eps=1e-6, + rope_base=10000 + ), + # https://huggingface.co/Qwen/Qwen2.5-Math-72B/blob/main/config.json + dict( + name="Qwen2.5-Math-72B{}", + hf_config=dict(org="Qwen", name="Qwen2.5-Math-72B{}"), + block_size=4096, + vocab_size=151643, + padded_vocab_size=152064, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + attn_bias=True, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=29568, + norm_eps=1e-5, + rope_base=10000 + ), +] + +qwen_2_5.extend(qwen_2_5_math) + for c in qwen_2_5: for kind in ("", "-Instruct"): copy = deepcopy(c) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 96a99073b6..51426e1523 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -284,6 +284,10 @@ def apply(self, prompt: str, **kwargs: str) -> str: system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class Qwen2_5_Math(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "Please reason step by step, and put your final answer within \\boxed{}." + return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" class QwQ(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: @@ -320,6 +324,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "llama3": Llama3, "olmo": OLMo, "qwen2.5": Qwen2_5, + "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, "salamandra": Salamandra, } @@ -360,6 +365,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Gemma() if re.search(r"OLMo.*-hf", model_name): return OLMo() + if re.search(r"Qwen2\.5-Math-.*", model_name): + return Qwen2_5_Math() if re.search(r"Qwen2\.5-.*", model_name): return Qwen2_5() if re.search(r"QwQ-.*", model_name): diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index 5e24827cef..5809f0063d 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -524,7 +524,7 @@ def test_check_conversion_supported_lora(): check_conversion_supported(lit_weights=lit_weights) @torch.inference_mode() -@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "QwQ-32B-Preview")) +@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview")) @pytest.mark.parametrize( ("device", "dtype"), [ diff --git a/tests/test_model.py b/tests/test_model.py index b7da3750b7..39e92f1204 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -792,7 +792,7 @@ def test_against_original_gemma_2(model_name, device, dtype): @torch.inference_mode() -@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "QwQ-32B-Preview")) +@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "Qwen2.5-Math-1.5B", "QwQ-32B-Preview")) @pytest.mark.parametrize( ("device", "dtype"), [ diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 96316f0f0b..49d3a6b619 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -37,6 +37,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | +| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | @@ -200,6 +201,12 @@ Qwen/Qwen2.5-Coder-14B Qwen/Qwen2.5-Coder-14B-Instruct Qwen/Qwen2.5-Coder-32B Qwen/Qwen2.5-Coder-32B-Instruct +Qwen/Qwen2.5-Math-1.5B +Qwen/Qwen2.5-Math-1.5B-Instruct +Qwen/Qwen2.5-Math-7B +Qwen/Qwen2.5-Math-7B-Instruct +Qwen/Qwen2.5-Math-72B +Qwen/Qwen2.5-Math-72B-Instruct Qwen/QwQ-32B-Preview stabilityai/FreeWilly2 stabilityai/stable-code-3b From 7b26d358134011b39217aa9ad09bf9a005e06201 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 16 Dec 2024 04:22:57 -0500 Subject: [PATCH 08/18] Add SmolLM2 (#1848) Co-authored-by: Andrei-Aksionov Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> --- README.md | 1 + litgpt/config.py | 76 ++++++++++++++++++++++++++++- litgpt/prompts.py | 9 ++++ litgpt/scripts/download.py | 2 +- litgpt/tokenizer.py | 4 +- tests/test_model.py | 61 +++++++++++++++++++++++ tutorials/download_model_weights.md | 7 +++ 7 files changed, 157 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3dcf6b1f4e..e12368fb7d 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,7 @@ Every model is written from scratch to maximize performance and remove layers of | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | +| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | diff --git a/litgpt/config.py b/litgpt/config.py index 54420826bb..475f017e50 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -2134,10 +2134,10 @@ def norm_class(self) -> Type: configs.extend(qwq) + ############# # Salamandra ############# - salamandra = [ # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json dict( @@ -2189,4 +2189,78 @@ def norm_class(self) -> Type: configs.append(copy) +############### +# SmolLM2 +############### +smollm2 = [ + # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json + dict( + name="SmolLM2-135M{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=30, + n_head=9, + n_embd=576, + n_query_groups=3, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=1536, + rope_base=100000, + norm_eps=1e-5, + ), + # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json + dict( + name="SmolLM2-360M{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=32, + n_head=15, + n_embd=960, + n_query_groups=5, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=2560, + rope_base=100000, + norm_eps=1e-5, + ), + # https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json + dict( + name="SmolLM2-1.7B{}", + hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"), + block_size=8192, + vocab_size=49152, + padded_vocab_size=49152, + n_layer=24, + n_head=32, + n_embd=2048, + n_query_groups=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8192, + rope_base=130000, + norm_eps=1e-5, + ), +] + +for c in smollm2: + for kind in ("", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 51426e1523..09b3277c7d 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -300,6 +300,12 @@ def apply(self, prompt: str, **kwargs: str) -> str: return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class SmolLM2(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face" + return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + # Maps prompt style names to PromptStyle classes prompt_styles: Dict[str, Type[PromptStyle]] = { # Dataset-specific prompt styles @@ -326,6 +332,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "qwen2.5": Qwen2_5, "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, + "smollm2": SmolLM2, "salamandra": Salamandra, } @@ -371,6 +378,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Qwen2_5() if re.search(r"QwQ-.*", model_name): return QwQ() + if re.search(r"SmolLM2.*-Instruct", model_name): + return SmolLM2() if re.search(r"salamandra-.*-instruct", model_name): return Salamandra() return Default() diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py index c1af2af133..fc6c153fad 100644 --- a/litgpt/scripts/download.py +++ b/litgpt/scripts/download.py @@ -131,7 +131,7 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s with gated_repo_catcher(repo_id, access_token): info = repo_info(repo_id, token=access_token) filenames = [f.rfilename for f in info.siblings] - bins = list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"])) + bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"])) safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"])) return bins, safetensors diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py index 6018a44734..10f7d031f6 100644 --- a/litgpt/tokenizer.py +++ b/litgpt/tokenizer.py @@ -87,7 +87,7 @@ def token_to_id(self, token: str) -> int: raise ValueError(f"token {token!r} not found in the collection.") return id_ - def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: + def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file(): return False with open(tokenizer_config_path, encoding="utf-8") as fp: @@ -96,6 +96,8 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: # `PreTrainedTokenizerFast` if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")): return True + if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"): + return True if "add_bos_token" in config: return config["add_bos_token"] # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True. diff --git a/tests/test_model.py b/tests/test_model.py index 39e92f1204..89e926d173 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -852,6 +852,7 @@ def test_against_original_qwen_2_5(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b")) @pytest.mark.parametrize( @@ -910,6 +911,66 @@ def test_against_original_salamandra(model_name, device, dtype): ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + + +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_smollm2(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + n_query_groups=2, + intermediate_size=86, + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) @RunIf(dynamo=True) diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 49d3a6b619..876db1916a 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -40,6 +40,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | +| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | @@ -122,6 +123,12 @@ google/gemma-2b-it google/gemma-7b google/gemma-7b-it h2oai/h2o-danube2-1.8b-chat +HuggingFaceTB/SmolLM2-135M +HuggingFaceTB/SmolLM2-135M-Instruct +HuggingFaceTB/SmolLM2-360M +HuggingFaceTB/SmolLM2-360M-Instruct +HuggingFaceTB/SmolLM2-1.7B +HuggingFaceTB/SmolLM2-1.7B-Instruct lmsys/longchat-13b-16k lmsys/longchat-7b-16k lmsys/vicuna-13b-v1.3 From 7e12d6475178deb4e289af3f1369820e2b1ac479 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 21 Dec 2024 05:36:33 -0500 Subject: [PATCH 09/18] Add Mistral-Large-Instruct-2411 (#1876) --- litgpt/config.py | 20 ++++++++++++++++++++ tutorials/download_model_weights.md | 1 + 2 files changed, 21 insertions(+) diff --git a/litgpt/config.py b/litgpt/config.py index 475f017e50..577a2f3335 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -1663,6 +1663,26 @@ def norm_class(self) -> Type: intermediate_size=28672, ) ) +configs.append( + # https://huggingface.co/mistralai/Mistral-Large-Instruct-2411/blob/main/config.json + dict( + name="Mistral-Large-Instruct-2411", + hf_config=dict(org="mistralai", name="Mistral-Large-Instruct-2411"), + padded_vocab_size=32768, + block_size=32768, + n_layer=88, + n_head=96, + n_embd=12288, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + norm_eps=1e-05, + mlp_class_name="LLaMAMLP", + intermediate_size=28672, + ) +) ############ diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 876db1916a..bd46a3564f 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -171,6 +171,7 @@ mistralai/Mistral-7B-Instruct-v0.3 mistralai/Mistral-7B-v0.1 mistralai/Mistral-7B-v0.3 mistralai/Mistral-Large-Instruct-2407 +mistralai/Mistral-Large-Instruct-2411 mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mixtral-8x7B-v0.1 mistralai/Mixtral-8x22B-Instruct-v0.1 From bb8e0dacfb9236c79d13ac6fbd24dc3a89094f95 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:17:47 +0300 Subject: [PATCH 10/18] Bump version for 0.5.4 release (#1883) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 75b0c808b2..e2aa18f04f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "litgpt" -version = "0.5.4.dev1" +version = "0.5.4" description = "Hackable implementation of state-of-the-art open-source LLMs" authors = [ { name = "Lightning AI", email = "contact@lightning.ai" }, From e63099c460d9e8cccf48c41476671f1f035b5b10 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:57:17 +0300 Subject: [PATCH 11/18] Temporary remove Thunder to make a release (#1884) --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e2aa18f04f..cdc4fac87f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ dependencies = [ "safetensors>=0.4.3", # download models "tokenizers>=0.15.2", # tokenization in most models "tqdm>=4.66.0", # convert_hf_checkpoint - "lightning-thunder @ git+https://github.com/Lightning-AI/lightning-thunder/ ; python_version >= '3.10' and sys_platform == 'linux'", ] [project.urls] From fe96c6366ad8fd20a632673bd6f344bd9b18ca04 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:20:16 +0300 Subject: [PATCH 12/18] Post-release setup for 0.5.5.dev1 (#1885) --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cdc4fac87f..1dd1e53743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "litgpt" -version = "0.5.4" +version = "0.5.5.dev1" description = "Hackable implementation of state-of-the-art open-source LLMs" authors = [ { name = "Lightning AI", email = "contact@lightning.ai" }, @@ -17,6 +17,7 @@ dependencies = [ "safetensors>=0.4.3", # download models "tokenizers>=0.15.2", # tokenization in most models "tqdm>=4.66.0", # convert_hf_checkpoint + "lightning-thunder @ git+https://github.com/Lightning-AI/lightning-thunder/ ; python_version >= '3.10' and sys_platform == 'linux'", ] [project.urls] From 1811ecc5e7b872b4fe277b74bc899f558b01ca4c Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 23 Dec 2024 12:22:38 -0500 Subject: [PATCH 13/18] Falcon3 (#1881) --- README.md | 1 + litgpt/config.py | 89 +++++++++++++++++++++++++++++ litgpt/prompts.py | 13 +++++ tests/test_model.py | 59 +++++++++++++++++++ tests/test_tokenizer.py | 5 ++ tutorials/download_model_weights.md | 9 +++ 6 files changed, 176 insertions(+) diff --git a/README.md b/README.md index e12368fb7d..c58a586fdc 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,7 @@ Every model is written from scratch to maximize performance and remove layers of | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | +| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | | FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | | Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | diff --git a/litgpt/config.py b/litgpt/config.py index 577a2f3335..af4098c2d3 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -441,6 +441,95 @@ def norm_class(self) -> Type: copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind) configs.append(copy) +falcon3 = [ + # https://huggingface.co/tiiuae/Falcon3-1B-Base/blob/main/config.json + dict( + name="Falcon3-1B{}", + hf_config=dict(org="tiiuae", name="Falcon3-1B{}"), + block_size=4096, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=18, + n_head=8, + n_query_groups=4, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=8192, + ), + # https://huggingface.co/tiiuae/Falcon3-3B-Base/blob/main/config.json + dict( + name="Falcon3-3B{}", + hf_config=dict(org="tiiuae", name="Falcon3-3B{}"), + block_size=32768, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=22, + n_head=12, + n_query_groups=4, + n_embd=3072, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=9216, + ), + # https://huggingface.co/tiiuae/Falcon3-7B-Base/blob/main/config.json + dict( + name="Falcon3-7B{}", + hf_config=dict(org="tiiuae", name="Falcon3-7B{}"), + block_size=32768, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=28, + n_head=12, + n_query_groups=4, + n_embd=3072, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=23040, + ), + # https://huggingface.co/tiiuae/Falcon3-10B-Base/blob/main/config.json + dict( + name="Falcon3-10B{}", + hf_config=dict(org="tiiuae", name="Falcon3-10B{}"), + block_size=32768, + vocab_size=131072, + padded_vocab_size=131072, + n_layer=40, + n_head=12, + n_query_groups=4, + n_embd=3072, + rotary_percentage=1.0, + parallel_residual=False, + rope_base=1000042, + norm_eps=1e-6, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=23040, + ), +] +for c in falcon3: + for kind in ("-Base", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + ############################# # OpenLM Research Open LLaMA diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 09b3277c7d..29cf7a4249 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -112,6 +112,17 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: ) +class Falcon3(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n" + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ( + [tokenizer.eos_id], + [tokenizer.token_to_id("<|endoftext|>")], + ) + + class Llama2FunctionCalling(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: # Has to be before the llama config @@ -344,6 +355,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return StableLMZephyr() if re.search("stablecode-instruct", model_name): return StableCode() + if re.search(r"Falcon3.*-Instruct", model_name): + return Falcon3() if re.search(r"falcon.*-instruct", model_name): return Falcon() if re.search("Llama-2-7b-chat-hf-function-calling-v2", model_name): diff --git a/tests/test_model.py b/tests/test_model.py index 89e926d173..9a21f0d34d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -972,6 +972,65 @@ def test_against_original_smollm2(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("Falcon3-1B-Base", "Falcon3-7B-Base")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_hf_falcon3(model_name, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + model_name, + padded_vocab_size=10000, + n_layer=2, + n_head=8, + n_embd=32, + n_query_groups=2, + intermediate_size=86, + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + @RunIf(dynamo=True) @torch.inference_mode() diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 43d440642c..d5c7d12699 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -60,6 +60,11 @@ def test_tokenizer_against_hf(config): # even though their config defines it, it's set as None in HF assert isinstance(ours.bos_id, int) assert theirs.bos_token_id is None + elif config.name.startswith("Falcon3"): + if isinstance(ours.bos_id, int): + assert theirs.bos_token_id is None + else: + assert ours.bos_id == theirs.bos_token_id == None else: assert ours.bos_id == theirs.bos_token_id diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index bd46a3564f..a170506c3d 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -12,6 +12,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) | | Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | +| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | | FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | | Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | @@ -233,6 +234,14 @@ tiiuae/falcon-40b tiiuae/falcon-40b-instruct tiiuae/falcon-7b tiiuae/falcon-7b-instruct +tiiuae/Falcon3-1B-Base +tiiuae/Falcon3-1B-Instruct +tiiuae/Falcon3-3B-Base +tiiuae/Falcon3-3B-Instruct +tiiuae/Falcon3-7B-Base +tiiuae/Falcon3-7B-Instruct +tiiuae/Falcon3-10B-Base +tiiuae/Falcon3-10B-Instruct TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T togethercomputer/LLaMA-2-7B-32K From 5670d46700153cd8b5f00907082a0d88366844be Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 23 Dec 2024 14:02:54 -0500 Subject: [PATCH 14/18] ChatML prompt template (#1882) --- litgpt/prompts.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 29cf7a4249..48850efd51 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -290,31 +290,32 @@ def apply(self, prompt: str, **kwargs: str) -> str: return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n" -class Qwen2_5(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class ChatML(PromptStyle): + def __init__(self, system_message: str): + self.system_message = system_message -class Qwen2_5_Math(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "Please reason step by step, and put your final answer within \\boxed{}." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + return f"<|im_start|>system\n{self.system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" -class QwQ(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class Qwen2_5(ChatML): + def __init__(self): + super().__init__("You are Qwen, created by Alibaba Cloud. You are a helpful assistant.") -class Salamandra(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit. La meva base de coneixement es va actualitzar per última vegada l'agost de 2023." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class Qwen2_5_Math(ChatML): + def __init__(self): + super().__init__("Please reason step by step, and put your final answer within \\boxed{}.") +class QwQ(ChatML): + def __init__(self): + super().__init__("You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.") -class SmolLM2(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face" - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class SmolLM2(ChatML): + def __init__(self): + super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face") + +class Salamandra(ChatML): + def __init__(self): + super().__init__("I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit.") # Maps prompt style names to PromptStyle classes From db308ba1f1aaf7cc2a56753f19d2cb33a4255f9b Mon Sep 17 00:00:00 2001 From: Matthias Seeger Date: Tue, 24 Dec 2024 20:24:35 +0100 Subject: [PATCH 15/18] Small fixes and refactoring (#1861) Co-authored-by: Andrei-Aksionov --- extensions/thunder/unsloth/executor.py | 2 +- litgpt/adapter.py | 4 +- litgpt/adapter_v2.py | 4 +- litgpt/config.py | 53 ++++++++++------- litgpt/generate/base.py | 2 +- litgpt/lora.py | 4 +- litgpt/model.py | 80 +++++++++++++++++--------- tests/test_generate.py | 8 ++- tests/test_generate_adapter.py | 10 +++- 9 files changed, 108 insertions(+), 59 deletions(-) diff --git a/extensions/thunder/unsloth/executor.py b/extensions/thunder/unsloth/executor.py index a0ed54598a..1779daf8ee 100644 --- a/extensions/thunder/unsloth/executor.py +++ b/extensions/thunder/unsloth/executor.py @@ -240,7 +240,7 @@ def unsloth_apply_rope_meta( Q: TensorProxy, cos: TensorProxy, sin: TensorProxy ) -> Tuple[TensorProxy, TensorProxy, TensorProxy, int, int, int]: batch, n_heads, seq_len, head_dim = Q.shape - assert seq_len <= cos.shape[0] + assert seq_len <= cos.shape[-2] BLOCK_SIZE, num_warps = kernels.calculate_settings(head_dim // 2) div, mod = divmod(n_heads, kernels.rope_embedding.ROPE_GROUP_SIZE) n_groups = div + (mod != 0) diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 8523cec814..628217b61c 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -132,8 +132,8 @@ def __init__(self, config: Config, block_idx: int) -> None: self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None self.block_idx = block_idx self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 1ad3d40b9d..7c94a8d630 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -179,8 +179,8 @@ def __init__(self, config: Config, block_idx: int) -> None: self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None self.block_idx = block_idx self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/config.py b/litgpt/config.py index af4098c2d3..a4a70c8238 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -15,23 +15,23 @@ class Config: name: str = "" hf_config: dict = field(default_factory=dict) - scale_embeddings: bool = False - attention_scores_scalar: Optional[int] = None + # General size parameters block_size: int = 4096 - sliding_window_size: Optional[int] = None - sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None + n_layer: int = 16 + n_embd: int = 4096 vocab_size: int = 50254 padding_multiple: int = 512 padded_vocab_size: Optional[int] = None - n_layer: int = 16 + # Transformer block (structure, normalizations) + norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + norm_eps: float = 1e-5 + post_attention_norm: bool = False + post_mlp_norm: bool = False + parallel_residual: bool = True + shared_attention_norm: bool = False + # Transformer block (self-attention) n_head: int = 32 head_size: Optional[int] = None - n_embd: int = 4096 - rotary_percentage: float = 0.25 - parallel_residual: bool = True - bias: bool = True - lm_head_bias: bool = False - attn_bias: bool = False # to use multi-head attention (MHA), set this to `n_head` (default) # to use multi-query attention (MQA), set this to 1 # to use grouped-query attention (GQA), set this to a value in between @@ -53,20 +53,29 @@ class Config: # # credit https://arxiv.org/pdf/2305.13245.pdf n_query_groups: Optional[int] = None - shared_attention_norm: bool = False - norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" - post_attention_norm: bool = False - post_mlp_norm: bool = False - norm_eps: float = 1e-5 - mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" - gelu_approximate: str = "none" - intermediate_size: Optional[int] = None - rope_condense_ratio: int = 1 + attn_bias: bool = False + attention_scores_scalar: Optional[int] = None + sliding_window_size: Optional[int] = None + sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None + # if `attention_logit_softcapping` is used, cannot use optimized + # `torch.nn.functional.scaled_dot_product_attention` (which implements + # Flash attention), may result in higher memory and runtime footprint. + attention_logit_softcapping: Optional[float] = None + # Rotary position embedding (RoPE) rope_base: int = 10000 + rotary_percentage: float = 0.25 + rope_condense_ratio: int = 1 rope_adjustments: Optional[dict] = None + # Transformer block (MLP) + intermediate_size: Optional[int] = None + bias: bool = True + mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" + gelu_approximate: str = "none" n_expert: int = 0 n_expert_per_token: int = 0 - attention_logit_softcapping: Optional[float] = None + # GPT before/after blocks + scale_embeddings: bool = False + lm_head_bias: bool = False final_logit_softcapping: Optional[float] = None def __post_init__(self): @@ -99,7 +108,7 @@ def __post_init__(self): self.rope_n_elem = int(self.rotary_percentage * self.head_size) if self.sliding_window_size is not None: - self.sliding_window_layer_placing = ( + self.sliding_window_layer_stride = ( 1 if (self.sliding_window_layer_placing is None or self.sliding_window_layer_placing == "all") else 2 ) diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index d349502489..866947beea 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -230,7 +230,7 @@ def batched_generate_fn( Args: model: The model to use. prompts: A 2D tensor of shape [batch_size, prompt_length]. - max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. + max_returned_tokens: The maximum number of tokens to return, including the prompt tokens. sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch. stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. include_prompt: Whether to output the prompt tokens. diff --git a/litgpt/lora.py b/litgpt/lora.py index 18a472337b..db48175eac 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -628,8 +628,8 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/model.py b/litgpt/model.py index 17b3b4ab04..643ba59a71 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -72,11 +72,30 @@ def _init_weights(self, module: nn.Module) -> None: torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + idx (torch.Tensor): Input token indices, shape `(B, T)` + input_pos (torch.Tensor, optional): Contains input positions, + either with shape `(T,)` or `(B, T)`, if provided. This is used + for generative inference, where a KV cache is required. By + default, this assumes `input_dim == arange(T)` with all inputs + up to `T` provided upfront. + + Returns: + torch.Tensor: Output (logits), shape `(B, T, config.padded_vocab_size)` + """ + if idx.dim() != 2: + raise ValueError(f"idx must have 2 dimensions, idx.shape = {idx.shape}") T = idx.size(1) if self.max_seq_length < T: raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") if input_pos is not None: # use the kv cache + if input_pos.dim() > 2: + # otherwise, things go wrong in `apply_rope` + raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}") + if input_pos.shape[-1] != T: + raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same") cos = batched_index_select(self.cos, 0, input_pos) sin = batched_index_select(self.sin, 0, input_pos) if self.mask_cache is None: @@ -87,20 +106,22 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) - # we get if input_pos has a batch dimension mask = mask.squeeze(1) else: - cos = self.cos[:T] - sin = self.sin[:T] - mask = None + # unsqueeze to have a batch dimension + cos = self.cos[:T].unsqueeze(0) + sin = self.sin[:T].unsqueeze(0) + # `cos`, `sin` have shape (1, T, config.rope_n_elem) + mask = None # defaults to causal mask - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd) if self.config.scale_embeddings: x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype) for block in self.transformer.h: x = block(x, cos, sin, mask, input_pos) x = self.transformer.ln_f(x) - x = self.lm_head(x) # (b, t, vocab_size) + x = self.lm_head(x) # (B, T, padded_vocab_size) if self.config.final_logit_softcapping is not None: - x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping + x = do_softcapping(x, self.config.final_logit_softcapping) return x @classmethod @@ -122,10 +143,8 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso elif num_params_present == 4: # These parameters should always be used together so that we don't interfere with standard rope extra_config = { - "original_max_seq_len": self.config.rope_adjustments["original_max_seq_len"], - "factor": self.config.rope_adjustments["factor"], - "low_freq_factor": self.config.rope_adjustments["low_freq_factor"], - "high_freq_factor": self.config.rope_adjustments["high_freq_factor"], + name: self.config.rope_adjustments[name] + for name in adjusted_params_required } else: # Some but not all parameters are specified; raise an error @@ -231,12 +250,13 @@ def forward( attention_output = self.post_attention_norm(attention_output) if self.config.parallel_residual: - x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x) - x = self.mlp(x_normed) + attention_output + x + if not self.config.shared_attention_norm: + x_normed = self.norm_2(x) + x = attention_output + x else: x = attention_output + x - x = self.post_mlp_norm(self.mlp(self.norm_2(x))) + x - return x + x_normed = self.norm_2(x) + return self.post_mlp_norm(self.mlp(x_normed)) + x class CausalSelfAttention(nn.Module): @@ -251,8 +271,8 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config @@ -275,15 +295,17 @@ def forward( qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - # split batched computation into three + # split batched computation into three: + # q: (B, n_query_groups, q_per_kv, T, hs) + # k, v: (B, n_query_groups, 1, T, hs) q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) # maybe repeat k and v if for the non multi-head attention cases # training: flash attention requires it # inference: multi-query would require a full kv cache so avoid it to limit its memory usage if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) - v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + k = k.expand(*q.shape) + v = v.expand(*q.shape) q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) @@ -331,11 +353,8 @@ def scaled_dot_product_attention( # with softcapping we cannot use SDPA if self.config.attention_logit_softcapping is not None: - scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) scores = q @ k.mT * scale - scores = ( - torch.tanh(scores / self.config.attention_logit_softcapping) * self.config.attention_logit_softcapping - ) + scores = do_softcapping(scores, self.config.attention_logit_softcapping) if mask is None: mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) @@ -496,10 +515,11 @@ def batched_index_select(t, dim, idx): res = torch.index_select(t, dim, idx.reshape(-1)) # flat index # split out single batch idx res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) - # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors - dims = [dim] + list(range(res.dim())) - del dims[dim + 1] - res = res.permute(dims) + if dim > 0: + # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors + dims = [dim] + list(range(res.dim())) + del dims[dim + 1] + res = res.permute(dims) # unflatten batch dims res = res.view(*batch_shape, *res.shape[1:]) return res @@ -556,6 +576,8 @@ def batched_index_copy_(t, dim, idx, val): def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + # x: (B, nh, T, hs) + # sin, cos: (B, T, hs) or (1, T, hs) head_size = x.size(-1) x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) @@ -571,6 +593,10 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T return roped.to(dtype=x.dtype) +def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor: + return torch.tanh(x / thresh) * thresh + + class KVCache(nn.Module): def __init__( self, diff --git a/tests/test_generate.py b/tests/test_generate.py index 6fc561b945..592f2c3acf 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -93,7 +93,13 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue() + err_value = err.getvalue() + expected_parts = [ + "'padded_vocab_size': 512", + "'n_layer': 2", + "'n_head': 4", + ] + assert all(part in err_value for part in expected_parts) def test_cli(): diff --git a/tests/test_generate_adapter.py b/tests/test_generate_adapter.py index 6e57ff0c5e..a40672d03e 100644 --- a/tests/test_generate_adapter.py +++ b/tests/test_generate_adapter.py @@ -55,7 +55,15 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like): pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'head_size': 2, 'n_embd': 8" in err.getvalue() + err_value = err.getvalue() + expected_parts = [ + "'padded_vocab_size': 512", + "'n_layer': 2", + "'n_head': 4", + "'head_size': 2", + "'n_embd': 8", + ] + assert all(part in err_value for part in expected_parts) @pytest.mark.parametrize("version", ("", "_v2")) From fabf765b62ebf03f36ab50a877a6ebaf24f64a86 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Thu, 26 Dec 2024 21:12:36 +0300 Subject: [PATCH 16/18] Drop interleave placement in QKV matrix (#1013) --- litgpt/adapter.py | 2 +- litgpt/adapter_v2.py | 16 +- litgpt/generate/tp.py | 13 +- litgpt/lora.py | 63 ++-- litgpt/model.py | 132 +++++--- litgpt/scripts/convert_hf_checkpoint.py | 374 +++++++++++------------ litgpt/scripts/convert_lit_checkpoint.py | 306 +++++++++---------- tests/test_adapter.py | 36 ++- tests/test_adapter_v2.py | 63 ++-- tests/test_convert_hf_checkpoint.py | 75 ++--- tests/test_convert_lit_checkpoint.py | 80 ++--- tests/test_generate_sequentially.py | 34 +-- tests/test_lora.py | 95 ++++-- tests/test_model.py | 34 ++- 14 files changed, 725 insertions(+), 598 deletions(-) diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 628217b61c..bef77ece1b 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -151,7 +151,7 @@ def scaled_dot_product_attention( ak, av = self.adapter_kv_cache else: prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) - aqkv = self.attn(prefix) + aqkv = self.qkv(prefix) q_per_kv = self.config.n_head // self.config.n_query_groups aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) aqkv = aqkv.permute(0, 2, 3, 1, 4) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 7c94a8d630..9b975260f0 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -21,6 +21,7 @@ from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention from litgpt.adapter import Config as BaseConfig from litgpt.model import KVCache +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -163,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) + self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) @@ -186,10 +187,10 @@ def __init__(self, config: Config, block_idx: int) -> None: self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: - """For compatibility with base checkpoints.""" + """For compatibility with base and/or legacy checkpoints.""" mapping = { - "attn.weight": "attn.linear.weight", - "attn.bias": "attn.linear.bias", + "qkv.weight": "qkv.linear.weight", + "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } @@ -197,6 +198,13 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa # For compatibility with older checkpoints if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + + for attr in ("weight", "bias"): + legacy_key = f"{prefix}attn.linear.{attr}" + current_key = f"{prefix}qkv.linear.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index c76d4f27c9..7b45ffd014 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -3,31 +3,30 @@ import logging import sys import time +import warnings from functools import partial from pathlib import Path from pprint import pprint from typing import Literal, Optional, Union -import warnings import lightning as L -from lightning_utilities.core.imports import RequirementCache import torch import torch._dynamo.config import torch._inductor.config from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities import rank_zero_only +from lightning_utilities.core.imports import RequirementCache import litgpt.generate.base as generate_base -from litgpt.model import GPT from litgpt.config import Config -from litgpt.tokenizer import Tokenizer -from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE +from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style +from litgpt.tokenizer import Tokenizer from litgpt.utils import ( check_nvlink_connectivity, check_valid_checkpoint_dir, extend_checkpoint_dir, - get_default_supported_precision + get_default_supported_precision, ) @@ -71,7 +70,7 @@ def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMA def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None: - tensor_parallel_linear(fabric, attn.attn, "colwise") + tensor_parallel_linear(fabric, attn.qkv, "colwise") tensor_parallel_linear(fabric, attn.proj, "rowwise") attn.register_forward_hook(partial(all_reduce_output, fabric.world_size)) diff --git a/litgpt/lora.py b/litgpt/lora.py index db48175eac..beca761c48 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -58,6 +58,7 @@ from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention from litgpt.model import KVCache +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -267,18 +268,14 @@ def lora_ind(self) -> torch.Tensor: # Indices are needed to properly pad weight updates with zeros. if not hasattr(self, "_lora_ind"): enable_q, enable_k, enable_v = self.enable_lora - qkv_group_size = self.n_head // self.n_query_groups + 2 - candidate_indices = range(self.linear.out_features) + kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups) lora_ind = [] if enable_q: - q_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size < qkv_group_size - 2] - lora_ind.extend(q_ind) + lora_ind.extend(range(0, self.linear.in_features)) if enable_k: - k_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 2] - lora_ind.extend(k_ind) + lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size)) if enable_v: - v_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 1] - lora_ind.extend(v_ind) + lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features)) self.register_buffer( "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False ) @@ -298,27 +295,6 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: ________________________________________ | query | key | value | ---------------------------------------- - For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped - queries are adjacent to their associated key and value weights. - For example, suppose we have n_head = 12 with 3 query groups. - Then along the embedding dimension the interleaved weights would look like - - [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V], - - where each Q, K, and V has size head_size. - - In this case, the previously-described weight update applies separately to each - individual block, so the update will take the form - - [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...], - [.............................................................................], - [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]] - ↑ ↑ ↑ ↑ ↑ ↑ - ________________________________________________________________________________ - | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ... - -------------------------------------------------------------------------------- - Note that in the above diagram, the size of each q block will equal q_per_kv - times the size of each k and v block. Args: x: tensor with weights update that will be padded with zeros if necessary @@ -391,7 +367,9 @@ def get_lora_AB(self) -> torch.Tensor: lora = self.conv1d( self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + ).squeeze( + 0 + ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128) def merge(self) -> None: @@ -430,7 +408,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: after_B = self.conv1d( after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + ).transpose( + -2, -1 + ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) return pretrained + lora @@ -602,7 +582,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = LoRAQKVLinear( + self.qkv = LoRAQKVLinear( in_features=config.n_embd, out_features=shape, r=config.lora_r, @@ -628,21 +608,28 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_stride == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: - """For compatibility with base checkpoints.""" + """For compatibility with base and/or legacy checkpoints.""" mapping = { - "attn.weight": "attn.linear.weight", - "attn.bias": "attn.linear.bias", + "qkv.weight": "qkv.linear.weight", + "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + + for attr in ("weight", "bias"): + legacy_key = f"{prefix}attn.linear.{attr}" + current_key = f"{prefix}qkv.linear.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) @@ -758,4 +745,4 @@ def merge_lora_weights(model: GPT) -> None: """Merge LoRA weights into the full-rank weights to speed up inference.""" for module in model.modules(): if isinstance(module, LoRALinear): - module.merge() \ No newline at end of file + module.merge() diff --git a/litgpt/model.py b/litgpt/model.py index 643ba59a71..cbdf2a4bdd 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -7,13 +7,14 @@ """ import math -from typing import Any, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn from typing_extensions import Self from litgpt.config import Config +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble class GPT(nn.Module): @@ -44,8 +45,10 @@ def max_seq_length(self, value: int) -> None: This allows setting a smaller number to avoid allocating unused memory """ if value > self.config.block_size: - raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}." - " This is likely because the input text exceeds the supported context length of this model.") + raise ValueError( + f"Cannot attend to {value}, block size is only {self.config.block_size}." + " This is likely because the input text exceeds the supported context length of this model." + ) self._max_seq_length = value if not hasattr(self, "cos"): # first call @@ -148,7 +151,9 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso } else: # Some but not all parameters are specified; raise an error - missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present] + missing_params = [ + param for param, present in zip(adjusted_params_required, params_present) if not present + ] raise ValueError( f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. " "All adjusted RoPE parameters must be specified together." @@ -180,7 +185,11 @@ def set_kv_cache( # initialize the kv cache for all blocks for block in self.transformer.h: block.attn.kv_cache = block.attn.build_kv_cache( - batch_size, max_seq_length, rope_cache_length, device, dtype, + batch_size, + max_seq_length, + rope_cache_length, + device, + dtype, ) if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: @@ -262,17 +271,20 @@ def forward( class CausalSelfAttention(nn.Module): def __init__(self, config: Config, block_idx: int) -> None: super().__init__() - shape = (config.n_head + 2 * config.n_query_groups) * config.head_size - # key, query, value projections for all heads, but in a batch - self.attn = nn.Linear(config.n_embd, shape, bias=config.bias or config.attn_bias) + # key, query and value projections for all heads, but in a batch + self.qkv = nn.Linear( + config.n_embd, + (config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries + bias=config.bias or config.attn_bias, + ) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_stride == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config @@ -285,42 +297,60 @@ def forward( mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - - qkv = self.attn(x) - - # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) - q_per_kv = self.config.n_head // self.config.n_query_groups - total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value - qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - - # split batched computation into three: - # q: (B, n_query_groups, q_per_kv, T, hs) - # k, v: (B, n_query_groups, 1, T, hs) - q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - - # maybe repeat k and v if for the non multi-head attention cases - # training: flash attention requires it - # inference: multi-query would require a full kv cache so avoid it to limit its memory usage - if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - k = k.expand(*q.shape) - v = v.expand(*q.shape) - - q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - + # Notation: + # - B | batch size + # - T | time-step (sequence length) + # - C | model's embeddings size (n_embd) + # - C* | attentions's embeddings size + # - nh_(q,k,v) | number of heads for query, key and value + # - hs | head size + + B, T, C = x.size() + + # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` + # instead of individually multiplying the input `x` with the respective weight matrices. + qkv = self.qkv(x) # (B, T, 3xC*) + + # Define query, key and value sizes. + # If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`). + query_size = self.config.n_head * self.config.head_size + key_size = value_size = self.config.n_query_groups * self.config.head_size + # Split qkv into query, key and value matrices. + q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) + + # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the + # embedding size (C) into num_heads (nh) and head_size (hs). + q = q.view(B, T, self.config.n_head, self.config.head_size) # (B, T, nh_q, hs) + k = k.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_k, hs) + v = v.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_v, hs) + + # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are + # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector + # of size `hs`. + q = q.transpose(1, 2) # (B, nh_q, T, hs) + k = k.transpose(1, 2) # (B, nh_k, T, hs) + v = v.transpose(1, 2) # (B, nh_v, T, hs) + + # Unlike standard positional embeddings rotary embeddings must be applied at every layer. q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_q, T, hs) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_k, T, hs) + # Apply kv-cache during inference. if input_pos is not None: if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) + # Grouped queries: balance the number of heads across all three matrices. + # NOTE: flash attention requires it in training mode. + # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. + if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): + q_per_kv = self.config.n_head // self.config.n_query_groups + k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + if self.apply_sliding_window_attention: """ Global Window Sliding window Sliding window @@ -339,12 +369,16 @@ def forward( sliding_window_bias.masked_fill_(sliding_window_bias.bool(), float("-inf")) mask += sliding_window_bias + # Efficient attention using Flash Attention CUDA kernels. + # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. + # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) y = self.scaled_dot_product_attention(q, k, v, mask) - y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side + # Re-assemble all head outputs side by side. + y = y.reshape(B, T, self.config.head_size * self.config.n_head) - # output projection - return self.proj(y) + # Output projection. + return self.proj(y) # (B, T, C) def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None @@ -375,8 +409,7 @@ def build_kv_cache( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> "KVCache": - heads = 1 if self.config.n_query_groups == 1 else self.config.n_head - v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + v_shape = (batch_size, self.config.n_query_groups, max_seq_length, self.config.head_size) if rope_cache_length is None: if self.config.rotary_percentage != 1.0: raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value") @@ -384,12 +417,23 @@ def build_kv_cache( else: k_shape = ( batch_size, - heads, + self.config.n_query_groups, max_seq_length, rope_cache_length + self.config.head_size - self.config.rope_n_elem, ) return KVCache(k_shape, v_shape, device=device, dtype=dtype) + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with legacy checkpoints.""" + + for attr in ("weight", "bias"): + legacy_key = f"{prefix}attn.{attr}" + current_key = f"{prefix}qkv.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) + + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + class GptNeoxMLP(nn.Module): def __init__(self, config: Config) -> None: diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 6125840ed9..fbcfa871a6 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -2,37 +2,38 @@ import gc import json +import os +import re from collections import defaultdict from functools import partial -import os from pathlib import Path from pprint import pprint from typing import Dict, List, Optional, Tuple, Union -from tqdm import tqdm import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor +from tqdm import tqdm from litgpt.config import Config from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config def copy_weights_gpt_neox( + config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False - + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "gpt_neox.embed_in.weight": "transformer.wte.weight", "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", - "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", - "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.qkv.bias", + "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, @@ -52,16 +53,16 @@ def copy_weights_gpt_neox( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights)) - for name, param in hf_weights.items(): - if "gpt_neox.layers" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template] + if to_name is None: + continue + to_name = to_name.format(layer_idx) + param = load_param(param, from_name, dtype, verbose=debug_mode) + if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): + # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -71,18 +72,18 @@ def copy_weights_gpt_neox( def copy_weights_falcon( - model_name: str, + config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "transformer.word_embeddings.weight": "transformer.wte.weight", - "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", @@ -91,14 +92,14 @@ def copy_weights_falcon( "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size - if "7b" in model_name: + if "7b" in config.name: weight_map.update( { "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", } ) - elif "40b" in model_name or "180B" in model_name: + elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", @@ -113,16 +114,17 @@ def copy_weights_falcon( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights)) - for name, param in hf_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, dtype, verbose=debug_mode) + if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): + # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param + if progress_per_file is not None: pbar.update(progress_per_file) @@ -136,19 +138,19 @@ def copy_weights_hf_llama( dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", - "model.layers.{}.input_layernorm.weight": "transformer.h.{l}.norm_1.weight", - "model.layers.{}.input_layernorm.bias": "transformer.h.{l}.norm_1.bias", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, - "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{l}.attn.proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, - "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{l}.norm_2.weight", - "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{l}.norm_2.bias", + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", "model.norm.weight": "transformer.ln_f.weight", "model.norm.bias": "transformer.ln_f.bias", "lm_head.weight": "lm_head.weight", @@ -156,18 +158,18 @@ def copy_weights_hf_llama( if config.mlp_class_name == "LLaMAMoE": weight_map.update( { - "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{l}.mlp.gate.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{l}.mlp.experts.{e}.fc_1.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", + "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{}.mlp.gate.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{}.mlp.experts.{}.proj.weight", } ) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { - "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", - "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", - "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight", + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", } ) else: @@ -176,26 +178,17 @@ def copy_weights_hf_llama( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l = layer_template(name, 2) - e = None - if "block_sparse_moe.experts" in name: - from_name, e = layer_template(from_name, 5) - qkv = qkv_weights.setdefault(l, [None, None, None]) - if "q_proj" in name: - qkv[0] = param - elif "k_proj" in name: - qkv[1] = param - elif "v_proj" in name: - qkv[2] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l=l, e=e) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -206,28 +199,24 @@ def copy_weights_hf_llama( if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] - # convert separate q, k, v matrices into an interleaved qkv - for i, (q, k, v) in list(qkv_weights.items()): - if q is None or k is None or v is None: - # split across different .bin files - continue - q = load_param(q, f"layer {i} q", dtype, verbose=debug_mode) - k = load_param(k, f"layer {i} k", dtype, verbose=debug_mode) - v = load_param(v, f"layer {i} v", dtype, verbose=debug_mode) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv - del qkv_weights[i] - if progress_per_file is not None: - pbar.update(progress_per_file) + for i in list(qkv_weights): + for weight_type in list(qkv_weights[i]): + qkv = qkv_weights[i][weight_type] + if len(qkv) != 3: + # qkv is splitted across different .bin files + continue + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv + del qkv_weights[i][weight_type] + + if progress_per_file is not None: + pbar.update(progress_per_file) def copy_weights_gemma_2( - config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], @@ -235,7 +224,7 @@ def copy_weights_gemma_2( dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", @@ -257,20 +246,17 @@ def copy_weights_gemma_2( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l_idx = layer_template(name, 2) - qkv = qkv_weights.setdefault(l_idx, defaultdict(dict)) - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype) + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -281,24 +267,19 @@ def copy_weights_gemma_2( if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] - # convert separate q, k, v matrices into an interleaved qkv for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue - q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype) - k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype) - v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] + if progress_per_file is not None: pbar.update(progress_per_file) @@ -312,7 +293,7 @@ def copy_weights_phi( dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: if any(layer_name.startswith(("layers.", "transformer.")) for layer_name in hf_weights): raise ValueError( @@ -344,7 +325,7 @@ def copy_weights_phi( if config.name.startswith("Phi-3"): weight_map.update( { - "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.attn.weight", + "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight", "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", @@ -355,35 +336,27 @@ def copy_weights_phi( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if name.startswith("model.layers."): - from_name, l = layer_template(name, 2) - qkv = qkv_weights.setdefault(l, defaultdict(dict)) - if "qkv_proj" in from_name: - weight = load_param(param, f"layer {l} qkv", dtype) - weight = qkv_reassemble(weight, config) - to_name = weight_map[from_name].format(l) - state_dict[to_name] = weight - continue - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - elif from_name.endswith("gate_up_proj.weight"): - weight = load_param(param, f"layer {l} gate_up_proj", dtype) - fc_1, fc_2 = weight.chunk(2, dim=0) - state_dict[f"transformer.h.{l}.mlp.fc_1.weight"] = fc_1 - state_dict[f"transformer.h.{l}.mlp.fc_2.weight"] = fc_2 - continue - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(layer_idx, defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + elif from_name.endswith("gate_up_proj.weight"): + weight = load_param(param, f"layer {layer_idx} gate_up_proj", dtype, verbose=debug_mode) + fc_1, fc_2 = weight.chunk(2, dim=0) + state_dict[f"transformer.h.{layer_idx}.mlp.fc_1.weight"] = fc_1 + state_dict[f"transformer.h.{layer_idx}.mlp.fc_2.weight"] = fc_2 + continue + to_name = weight_map[name_template] + if to_name is None: + continue + to_name = to_name.format(layer_idx) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param + if progress_per_file is not None: pbar.update(progress_per_file) @@ -391,19 +364,15 @@ def copy_weights_phi( for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] + if progress_per_file is not None: pbar.update(progress_per_file) @@ -439,20 +408,17 @@ def copy_weights_qwen_2_5( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l = layer_template(name, 2) - qkv = qkv_weights.setdefault(l, defaultdict(dict)) - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -467,49 +433,52 @@ def copy_weights_qwen_2_5( for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] + if progress_per_file is not None: pbar.update(progress_per_file) -def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: + +def qkv_reassemble( + param: Union[torch.Tensor, NotYetLoadedTensor], config: Config +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Reassemble from a normal to an interleaved placement in a QKV matrix. - [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] + [Q, K, V, Q, K, V, ...] --> [Q, Q, ..., K, K, ..., V, V, ...] """ - q, k, v = param.split( - ( - config.n_head * config.head_size, - config.n_query_groups * config.head_size, - config.n_query_groups * config.head_size, - ) - ) - qs = q.split(config.n_head // config.n_query_groups * config.head_size) - ks = k.split(config.head_size) - vs = v.split(config.head_size) - interleaved = [t for group in zip(qs, ks, vs) for t in group] - return torch.cat(interleaved) - - -def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: - split = layer_name.split(".") - number = int(split[idx]) - split[idx] = "{}" - from_name = ".".join(split) - return from_name, number - - -def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose=False) -> torch.Tensor: + q_per_kv = config.n_head // config.n_query_groups + qs = [] + ks = [] + vs = [] + for chunk in torch.chunk(param, config.n_query_groups): + split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) + qs.append(split[0]) + ks.append(split[1]) + vs.append(split[2]) + q = torch.cat(qs) + k = torch.cat(ks) + v = torch.cat(vs) + return torch.cat((q, k, v)) + + + +def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]: + pattern = r"\.(\d+)\." + if not (search_res := re.findall(pattern, layer_name)): + return layer_name, -1 + layer_name_template = re.sub(pattern, ".{}.", layer_name, count=num_matches) + return layer_name_template, *(int(x) for x in search_res[:num_matches]) + + +def load_param( + param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool =False +) -> torch.Tensor: if hasattr(param, "_load_tensor"): # support tensors loaded via `lazy_load()` if verbose: @@ -522,13 +491,14 @@ def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: return param + @torch.inference_mode() def convert_hf_checkpoint( checkpoint_dir: Path, *, model_name: Optional[str] = None, dtype: Optional[str] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: """ Convert a Hugging Face Transformers checkpoint into a LitGPT compatible checkpoint. @@ -554,10 +524,10 @@ def convert_hf_checkpoint( save_config(config, checkpoint_dir) if "falcon" in model_name: - copy_fn = partial(copy_weights_falcon, model_name) + copy_fn = partial(copy_weights_falcon, config) elif model_name.lower().startswith("gemma-2"): qkv_weights = {} - copy_fn = partial(copy_weights_gemma_2, config, qkv_weights) + copy_fn = partial(copy_weights_gemma_2, qkv_weights) elif model_name.lower().startswith("phi"): # holder to reconstitute the split q, k, v qkv_weights = {} @@ -571,7 +541,7 @@ def convert_hf_checkpoint( qkv_weights = {} copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) else: - copy_fn = copy_weights_gpt_neox + copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} @@ -604,14 +574,26 @@ def convert_hf_checkpoint( total_size = max(1, sum(os.path.getsize(bin_file) for bin_file in bin_files)) total_progress = 100 - with tqdm(total=total_progress, desc="Initializing", bar_format="{desc}{percentage:3.0f}%|{bar}| {elapsed}<{remaining}, {rate_fmt}") as pbar: + with tqdm( + total=total_progress, + desc="Initializing", + bar_format="{desc}{percentage:3.0f}%|{bar}| {elapsed}<{remaining}, {rate_fmt}", + ) as pbar: for bin_file in sorted(bin_files): pbar.set_description(f"Loading weights: {bin_file.name}") current_file_size = os.path.getsize(bin_file) progress_per_file = (current_file_size / total_size) * total_progress hf_weights = lazy_load(bin_file) - copy_fn(sd, hf_weights, saver=saver, dtype=dtype, pbar=pbar, progress_per_file=progress_per_file, debug_mode=debug_mode) + copy_fn( + sd, + hf_weights, + saver=saver, + dtype=dtype, + pbar=pbar, + progress_per_file=progress_per_file, + debug_mode=debug_mode, + ) gc.collect() if pbar.n < total_progress: diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index a994b3022a..f276e3ae31 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -5,7 +5,7 @@ from functools import partial from pathlib import Path from pprint import pprint -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor @@ -16,14 +16,14 @@ def copy_weights_falcon( - model_name: str, + config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "transformer.word_embeddings.weight", - "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", + "transformer.h.{}.attn.qkv.weight": "transformer.h.{}.self_attention.query_key_value.weight", "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", @@ -32,14 +32,14 @@ def copy_weights_falcon( "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size - if "7b" in model_name: + if "7b" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", } ) - elif "40b" in model_name or "180B" in model_name: + elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", @@ -51,19 +51,20 @@ def copy_weights_falcon( else: raise NotImplementedError - for name, param in lit_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): + # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gpt_neox( + config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, @@ -72,8 +73,8 @@ def copy_weights_gpt_neox( "transformer.wte.weight": "gpt_neox.embed_in.weight", "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", - "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", - "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", + "transformer.h.{}.attn.qkv.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", + "transformer.h.{}.attn.qkv.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", @@ -87,13 +88,13 @@ def copy_weights_gpt_neox( "lm_head.weight": "embed_out.weight", } - for name, param in lit_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): + # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -108,11 +109,11 @@ def copy_weights_llama( ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", - "transformer.h.{}.norm_1.weight": "model.layers.{l}.input_layernorm.weight", - "transformer.h.{}.norm_1.bias": "model.layers.{l}.input_layernorm.bias", - "transformer.h.{}.attn.proj.weight": "model.layers.{l}.self_attn.o_proj.weight", - "transformer.h.{}.norm_2.weight": "model.layers.{l}.post_attention_layernorm.weight", - "transformer.h.{}.norm_2.bias": "model.layers.{l}.post_attention_layernorm.bias", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias", "transformer.ln_f.weight": "model.norm.weight", "transformer.ln_f.bias": "model.norm.bias", "lm_head.weight": "lm_head.weight", @@ -120,48 +121,46 @@ def copy_weights_llama( if config.mlp_class_name == "LLaMAMoE": weight_map.update( { - "transformer.h.{}.mlp.gate.weight": "model.layers.{l}.block_sparse_moe.gate.weight", - "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight", - "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight", - "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight", + "transformer.h.{}.mlp.gate.weight": "model.layers.{}.block_sparse_moe.gate.weight", + "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.block_sparse_moe.experts.{}.w1.weight", + "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.block_sparse_moe.experts.{}.w3.weight", + "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.block_sparse_moe.experts.{}.w2.weight", } ) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { - "transformer.h.{}.mlp.fc_1.weight": "model.layers.{l}.mlp.gate_proj.weight", - "transformer.h.{}.mlp.fc_2.weight": "model.layers.{l}.mlp.up_proj.weight", - "transformer.h.{}.mlp.proj.weight": "model.layers.{l}.mlp.down_proj.weight", + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", } ) else: raise NotImplementedError - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith(".attn.attn.weight"): - from_name, l = layer_template(name, 2) - q = "model.layers.{}.self_attn.q_proj.weight".format(l) - k = "model.layers.{}.self_attn.k_proj.weight".format(l) - v = "model.layers.{}.self_attn.v_proj.weight".format(l) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith(".attn.qkv.weight"): + to_names = ( + "model.layers.{}.self_attn.q_proj.weight".format(*ids), + "model.layers.{}.self_attn.k_proj.weight".format(*ids), + "model.layers.{}.self_attn.v_proj.weight".format(*ids), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: - if "transformer.h" in name: - from_name, l = layer_template(name, 2) - e = None - if "mlp.experts" in name: - from_name, e = layer_template(from_name, 5) - to_name = weight_map[from_name] - to_name = to_name.format(l=l, e=e) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -188,31 +187,29 @@ def copy_weights_gemma_2( "lm_head.weight": "lm_head.weight", } - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith(".attn.attn.weight"): - from_name, layer_idx = layer_template(name, 2) - q = "model.layers.{}.self_attn.q_proj.weight".format(layer_idx) - k = "model.layers.{}.self_attn.k_proj.weight".format(layer_idx) - v = "model.layers.{}.self_attn.v_proj.weight".format(layer_idx) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith(".attn.qkv.weight"): + to_names = ( + "model.layers.{}.self_attn.q_proj.weight".format(*ids), + "model.layers.{}.self_attn.k_proj.weight".format(*ids), + "model.layers.{}.self_attn.v_proj.weight".format(*ids), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: - if "transformer.h" in name: - from_name, layer_idx = layer_template(name, 2) - e = None - if "mlp.experts" in name: - from_name, e = layer_template(from_name, 5) - to_name = weight_map[from_name] - to_name = to_name.format(layer_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -239,11 +236,10 @@ def copy_weights_phi( "lm_head.weight": "lm_head.weight", "lm_head.bias": "lm_head.bias", } - if config.name.startswith("Phi-3"): weight_map.update( { - "transformer.h.{}.attn.attn.weight": "model.layers.{}.self_attn.qkv_proj.weight", + "transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", @@ -252,51 +248,48 @@ def copy_weights_phi( ) gate_up_proj_weights = defaultdict(dict) - for name, param in lit_weights.items(): - if name.endswith((".attn.attn.weight", ".attn.attn.bias")): - from_name, l_idx = layer_template(name, 2) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): if config.name.startswith("Phi-3"): - qkv_reassembled = torch.concat([qp, kp, vp], dim=0) - to_name = weight_map[from_name].format(l_idx) - if saver is not None: - qkv_reassembled = saver.store_early(qkv_reassembled) - state_dict[to_name] = qkv_reassembled + to_names = (weight_map[name_template].format(layer_idx),) + params = (param,) else: - weight_type = name.split(".")[-1] # weight or bias - q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}" - k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}" - v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}" - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param - elif name.endswith((".fc_1.weight", ".fc_2.weight")): - from_name, l_idx = layer_template(name, 2) - weight = load_param(param, name, None) - weight_name = name.split(".")[-2] - gate_up_proj_weights[l_idx][weight_name] = weight + weight_type = from_name.split(".")[-1] # weight or bias + to_names = ( + f"model.layers.{{}}.self_attn.q_proj.{weight_type}".format(layer_idx), + f"model.layers.{{}}.self_attn.k_proj.{weight_type}".format(layer_idx), + f"model.layers.{{}}.self_attn.v_proj.{weight_type}".format(layer_idx), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) + elif from_name.endswith((".fc_1.weight", ".fc_2.weight")): + weight = load_param(param, from_name, None) + weight_name = from_name.split(".")[-2] + gate_up_proj_weights[layer_idx][weight_name] = weight else: - if "transformer.h" in name: - from_name, l_idx = layer_template(name, 2) - to_name = weight_map[from_name] - to_name = to_name.format(l_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(layer_idx),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if config.name.startswith("Phi-3"): - for i in list(gate_up_proj_weights): - fc_1_weight = gate_up_proj_weights[i]["fc_1"] - fc_2_weight = gate_up_proj_weights[i]["fc_2"] + for layer_idx in list(gate_up_proj_weights): + fc_1_weight = gate_up_proj_weights[layer_idx]["fc_1"] + fc_2_weight = gate_up_proj_weights[layer_idx]["fc_2"] weight = torch.concat([fc_1_weight, fc_2_weight], dim=0) - layer_name = f"model.layers.{i}.mlp.gate_up_proj.weight" + layer_name = f"model.layers.{layer_idx}.mlp.gate_up_proj.weight" state_dict[layer_name] = weight - del gate_up_proj_weights[i] + del gate_up_proj_weights[layer_idx] def copy_weights_qwen_2_5( config: Config, @@ -317,50 +310,51 @@ def copy_weights_qwen_2_5( "lm_head.weight": "lm_head.weight", } - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith((".attn.attn.weight", ".attn.attn.bias")): - from_name, l_idx = layer_template(name, 2) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - - weight_type = name.split(".")[-1] # weight or bias - q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}" - k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}" - v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}" - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): + weight_type = from_name.split(".")[-1] # weight or bias + to_names = ( + "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: - if "transformer.h" in name: - from_name, l_idx = layer_template(name, 2) - to_name = weight_map[from_name] - to_name = to_name.format(l_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param -def qkv_split( - param: Union[torch.Tensor, NotYetLoadedTensor], config: Config -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q_per_kv = config.n_head // config.n_query_groups - qs = [] - ks = [] - vs = [] - for chunk in torch.chunk(param, config.n_query_groups): - split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) - qs.append(split[0]) - ks.append(split[1]) - vs.append(split[2]) - q = torch.cat(qs) - k = torch.cat(ks) - v = torch.cat(vs) - return q, k, v + +def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: + """Reassemble from a normal to an interleaved placement in a QKV matrix. + [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] + """ + q, k, v = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) + qs = q.split(config.n_head // config.n_query_groups * config.head_size) + ks = k.split(config.head_size) + vs = v.split(config.head_size) + interleaved = [t for group in zip(qs, ks, vs) for t in group] + return torch.cat(interleaved) def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: @@ -382,7 +376,7 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: output_path = output_dir / "model.pth" if "falcon" in config.name: - copy_fn = partial(copy_weights_falcon, config.name) + copy_fn = partial(copy_weights_falcon, config) elif config.name.startswith("Gemma-2"): copy_fn = partial(copy_weights_gemma_2, config) elif config.name.lower().startswith("phi"): @@ -393,7 +387,7 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) else: - copy_fn = copy_weights_gpt_neox + copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} diff --git a/tests/test_adapter.py b/tests/test_adapter.py index da422f6288..9deb7be1f7 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout +from copy import deepcopy from dataclasses import asdict from io import StringIO from unittest import mock @@ -19,10 +20,11 @@ import litgpt.adapter as gpt_adapter import litgpt.finetune.adapter as module import litgpt.model as gpt -from litgpt.adapter import GPT, Config, adapter_filter +from litgpt.adapter import GPT, CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -192,7 +194,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca "transformer.h.0.norm_1.weight", "transformer.h.0.norm_1.bias", "transformer.h.0.attn.gating_factor", - "transformer.h.0.attn.attn.bias", + "transformer.h.0.attn.qkv.bias", "transformer.h.0.attn.proj.bias", "transformer.h.0.attn.adapter_wte.weight", "transformer.h.0.norm_2.weight", @@ -202,7 +204,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca "transformer.h.1.norm_1.weight", "transformer.h.1.norm_1.bias", "transformer.h.1.attn.gating_factor", - "transformer.h.1.attn.attn.bias", + "transformer.h.1.attn.qkv.bias", "transformer.h.1.attn.proj.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.h.1.norm_2.weight", @@ -214,11 +216,11 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca }, "torch.uint8": { "lm_head.weight", - "transformer.h.0.attn.attn.weight", + "transformer.h.0.attn.qkv.weight", "transformer.h.0.attn.proj.weight", "transformer.h.0.mlp.fc.weight", "transformer.h.0.mlp.proj.weight", - "transformer.h.1.attn.attn.weight", + "transformer.h.1.attn.qkv.weight", "transformer.h.1.attn.proj.weight", "transformer.h.1.mlp.fc.weight", "transformer.h.1.mlp.proj.weight", @@ -345,7 +347,7 @@ def test_against_original_gemma_2(model_name, device, dtype): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -355,3 +357,25 @@ def test_against_original_gemma_2(model_name, device, dtype): ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config) + state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index aec205155d..ca00a5d641 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout +from copy import deepcopy from io import StringIO from unittest import mock from unittest.mock import Mock @@ -19,11 +20,12 @@ import litgpt.config as config_module import litgpt.finetune.adapter_v2 as module from litgpt.adapter_v2 import GPT as AdapterV2GPT -from litgpt.adapter_v2 import Config, adapter_filter +from litgpt.adapter_v2 import CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -33,10 +35,10 @@ def test_config_identical(): base_model = BaseGPT.from_name(name) adapter_model = AdapterV2GPT.from_name(name) - assert not hasattr(base_model.transformer.h[2].attn.attn, "adapter_bias") - assert not hasattr(base_model.transformer.h[2].attn.attn, "adapter_scale") - assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_bias") - assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_scale") + assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_bias") + assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_scale") + assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_bias") + assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_scale") def test_adapter_v2_filter(tmp_path): @@ -56,8 +58,8 @@ def test_adapter_v2_filter(tmp_path): } for layer in range(3): for param in ( - "attn.attn.adapter_bias", - "attn.attn.adapter_scale", + "attn.qkv.adapter_bias", + "attn.qkv.adapter_scale", "attn.proj.adapter_bias", "attn.proj.adapter_scale", "mlp.fc.adapter_bias", @@ -297,7 +299,7 @@ def test_against_original_gemma_2(model_name): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = AdapterV2GPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys @@ -364,27 +366,27 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "torch.uint8": { "transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", - "transformer.h.1.attn.attn.linear.weight", + "transformer.h.1.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", - "transformer.h.0.attn.attn.linear.weight", + "transformer.h.0.attn.qkv.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { - "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.1.mlp.proj.adapter_bias", - "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.attn.qkv.adapter_bias", "transformer.h.0.norm_1.bias", - "transformer.h.0.attn.attn.linear.bias", + "transformer.h.0.attn.qkv.linear.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.mlp.proj.linear.bias", "transformer.h.1.mlp.fc.linear.bias", "transformer.h.0.attn.proj.adapter_scale", - "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.0.attn.qkv.adapter_scale", "transformer.h.1.norm_2.bias", "transformer.h.1.attn.proj.adapter_scale", "transformer.h.0.norm_2.bias", @@ -406,9 +408,9 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "lm_head.adapter_bias", "transformer.h.1.norm_2.weight", "transformer.h.0.attn.adapter_wte.weight", - "transformer.h.1.attn.attn.adapter_scale", + "transformer.h.1.attn.qkv.adapter_scale", "transformer.h.1.mlp.fc.adapter_scale", - "transformer.h.1.attn.attn.linear.bias", + "transformer.h.1.attn.qkv.linear.bias", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", @@ -435,20 +437,20 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "transformer.ln_f.bias", "lm_head.adapter_scale", "transformer.h.1.norm_2.weight", - "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.0.attn.qkv.adapter_scale", "transformer.h.0.mlp.proj.adapter_bias", "transformer.h.0.attn.gating_factor", "transformer.h.1.norm_1.bias", "transformer.h.1.mlp.fc.adapter_bias", "transformer.h.1.mlp.proj.adapter_scale", "transformer.h.0.mlp.fc.adapter_scale", - "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.0.norm_2.weight", "transformer.h.1.norm_2.bias", "transformer.h.0.norm_1.weight", "transformer.h.0.attn.proj.adapter_scale", "transformer.h.1.mlp.proj.adapter_bias", - "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.attn.qkv.adapter_bias", "transformer.h.0.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.1.attn.gating_factor", @@ -458,10 +460,31 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "transformer.h.0.norm_1.bias", "transformer.h.0.norm_2.bias", "transformer.h.1.norm_1.weight", - "transformer.h.1.attn.attn.adapter_scale", + "transformer.h.1.attn.qkv.adapter_scale", } } logs = stdout.getvalue() assert "of trainable parameters: 552" in logs assert "of non-trainable parameters: 1,808" in logs + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config) + state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 08749e521d..38b41f711d 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -6,7 +6,7 @@ import torch from litgpt import Config -from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint, copy_weights_hf_llama +from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint, copy_weights_hf_llama, qkv_reassemble def test_llama2_70b_conversion(): @@ -17,10 +17,10 @@ def test_llama2_70b_conversion(): "model.layers.0.mlp.gate_proj.weight": (28672, 8192), "model.layers.0.mlp.up_proj.weight": (28672, 8192), "model.layers.0.post_attention_layernorm.weight": (8192,), - "model.layers.0.self_attn.k_proj.weight": (1024, 8192), - "model.layers.0.self_attn.o_proj.weight": (8192, 8192), "model.layers.0.self_attn.q_proj.weight": (8192, 8192), + "model.layers.0.self_attn.k_proj.weight": (1024, 8192), "model.layers.0.self_attn.v_proj.weight": (1024, 8192), + "model.layers.0.self_attn.o_proj.weight": (8192, 8192), "model.layers.1.input_layernorm.weight": (8192,), "model.layers.1.mlp.down_proj.weight": (8192, 28672), "model.layers.1.mlp.gate_proj.weight": (28672, 8192), @@ -56,14 +56,14 @@ def test_llama2_70b_conversion(): weight_map = {k: torch.empty(s) for k, s in shapes.items()} copy_weights_hf_llama(config, qkv_weights, holder, weight_map) - # we are only testing 5 layers - assert len(qkv_weights) == 5 + # NOTE: there are 5 layers, but only in the first layer we have `q`, `k` and `v` + assert len(qkv_weights) == 1 # there are no loaded qkv weights assert all(v is None for qkv in qkv_weights.values() for v in qkv) # the shapes are correct holder = {k: tuple(t.shape) for k, t in holder.items()} assert holder == { - "transformer.h.0.attn.attn.weight": (10240, 8192), + "transformer.h.0.attn.qkv.weight": (10240, 8192), "transformer.h.0.attn.proj.weight": (8192, 8192), "transformer.h.0.mlp.fc_1.weight": (28672, 8192), "transformer.h.0.mlp.fc_2.weight": (28672, 8192), @@ -101,14 +101,18 @@ def test_llama2_70b_conversion(): } -def test_convert_hf_checkpoint(tmp_path): +@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2")) +def test_convert_hf_checkpoint(tmp_path, model_name): with pytest.raises(ValueError, match="to contain .bin"): - convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name="pythia-14m") + convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) bin_file = tmp_path / "foo.bin" bin_file.touch() with mock.patch("litgpt.scripts.convert_hf_checkpoint.lazy_load") as load: - convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name="pythia-14m") + # bypass if-statement for weight tying + if model_name == "Llama-2-7b-hf": + load.return_value = {"model.embed_tokens.weight": torch.rand((10, 10))} + convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) load.assert_called_with(bin_file) assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "model_config.yaml", "lit_model.pth"} @@ -119,43 +123,40 @@ def test_convert_hf_checkpoint(tmp_path): def test_qkv_reassemble(): - from litgpt import Config - from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble - # MHA config = Config(n_embd=4, n_head=4) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ [0, 1, 2, 3], # query - [4, 5, 6, 7], # query - [8, 9, 10, 11], # query - [12, 13, 14, 15], # query [16, 17, 18, 19], # key - [20, 21, 22, 23], # key - [24, 25, 26, 27], # key - [28, 29, 30, 31], # key [32, 33, 34, 35], # value + [4, 5, 6, 7], # query + [20, 21, 22, 23], # key [36, 37, 38, 39], # value + [8, 9, 10, 11], # query + [24, 25, 26, 27], # key [40, 41, 42, 43], # value + [12, 13, 14, 15], # query + [28, 29, 30, 31], # key [44, 45, 46, 47], # value ] ) - qkv_interleaved = qkv_reassemble(qkv, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( - qkv_interleaved, + qkv, torch.tensor( [ [0, 1, 2, 3], # query - [16, 17, 18, 19], # key - [32, 33, 34, 35], # value [4, 5, 6, 7], # query - [20, 21, 22, 23], # key - [36, 37, 38, 39], # value [8, 9, 10, 11], # query - [24, 25, 26, 27], # key - [40, 41, 42, 43], # value [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # key [28, 29, 30, 31], # key + [32, 33, 34, 35], # value + [36, 37, 38, 39], # value + [40, 41, 42, 43], # value [44, 45, 46, 47], # value ] ), @@ -163,30 +164,30 @@ def test_qkv_reassemble(): # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query + [16, 17, 18, 19], # key + [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query - [16, 17, 18, 19], # key [20, 21, 22, 23], # key - [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ) - qkv_interleaved = qkv_reassemble(qkv, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( - qkv_interleaved, + qkv, torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query - [16, 17, 18, 19], # key - [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query + [16, 17, 18, 19], # key [20, 21, 22, 23], # key + [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ), @@ -194,7 +195,7 @@ def test_qkv_reassemble(): # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query @@ -204,9 +205,9 @@ def test_qkv_reassemble(): [20, 21, 22, 23], # value ] ) - qkv_interleaved = qkv_reassemble(qkv, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( - qkv_interleaved, + qkv, torch.tensor( [ [0, 1, 2, 3], # query diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index 5809f0063d..9e0cd93c35 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -15,6 +15,10 @@ from transformers.models.llama import LlamaConfig, LlamaForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.olmo import OlmoConfig, OlmoForCausalLM +from transformers.models.phi.configuration_phi import PhiConfig +from transformers.models.phi.modeling_phi import PhiForCausalLM +from transformers.models.phi3.configuration_phi3 import Phi3Config +from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from litgpt import GPT, Config @@ -27,13 +31,14 @@ copy_weights_llama, copy_weights_phi, copy_weights_qwen_2_5, - qkv_split, + qkv_reassemble, ) from tests.conftest import RunIf -def test_convert_lit_checkpoint(tmp_path): - ours_config = Config.from_name("Llama-2-7b-hf", block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128) +@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2")) +def test_convert_lit_checkpoint(tmp_path, model_name): + ours_config = Config.from_name(model_name, block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" config_path = tmp_path / "model_config.yaml" @@ -70,7 +75,7 @@ def test_against_falcon_40b(): ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} - copy_weights_falcon("40b", theirs_state_dict, ours_state_dict) + copy_weights_falcon(ours_config, theirs_state_dict, ours_state_dict) theirs_model = FalconForCausalLM(theirs_config) # assign must be set to True for torch.testing.assert_close to pass @@ -105,7 +110,7 @@ def test_against_original_gpt_neox(): ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} - copy_weights_gpt_neox(theirs_state_dict, ours_state_dict) + copy_weights_gpt_neox(ours_config, theirs_state_dict, ours_state_dict) theirs_model = GPTNeoXForCausalLM(theirs_config) # strict=False because we don't save the rotary embeddings inv frequency keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) @@ -196,6 +201,7 @@ def test_against_mixtral(model_name): theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() @pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf")) def test_against_olmo(model_name): @@ -239,6 +245,7 @@ def test_against_olmo(model_name): theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() def test_against_original_open_llama_3b(): ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86) @@ -270,9 +277,6 @@ def test_against_original_open_llama_3b(): @torch.inference_mode() @pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2")) def test_against_hf_phi(model_name): - from transformers.models.phi.configuration_phi import PhiConfig - from transformers.models.phi.modeling_phi import PhiForCausalLM - ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5 ) @@ -308,9 +312,6 @@ def test_against_hf_phi(model_name): @torch.inference_mode() @pytest.mark.parametrize("model_name", ("Phi-3-mini-4k-instruct",)) def test_against_hf_phi_3(model_name): - from transformers.models.phi3.configuration_phi3 import Phi3Config - from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM - ours_config = Config.from_name(model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256) T = 5 theirs_config = Phi3Config( @@ -425,7 +426,10 @@ def test_against_original_gemma(model_name, device, dtype): theirs_state_dict = {} copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True) theirs_model = GemmaForCausalLM(theirs_config).to(device) - theirs_model.load_state_dict(theirs_state_dict, strict=False,) + theirs_model.load_state_dict( + theirs_state_dict, + strict=False, + ) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) @@ -587,41 +591,41 @@ def test_against_original_qwen_2_5(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) -def test_qkv_split(): +def test_qkv_reassemble(): # MHA config = Config(n_embd=4, n_head=4) - qkv_interleaved = torch.tensor( + qkv = torch.tensor( [ [0, 1, 2, 3], # query - [16, 17, 18, 19], # key - [32, 33, 34, 35], # value [4, 5, 6, 7], # query - [20, 21, 22, 23], # key - [36, 37, 38, 39], # value [8, 9, 10, 11], # query - [24, 25, 26, 27], # key - [40, 41, 42, 43], # value [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # key [28, 29, 30, 31], # key + [32, 33, 34, 35], # value + [36, 37, 38, 39], # value + [40, 41, 42, 43], # value [44, 45, 46, 47], # value ] ) - qkv = torch.cat(qkv_split(qkv_interleaved, config)) + qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( - qkv, + qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query - [4, 5, 6, 7], # query - [8, 9, 10, 11], # query - [12, 13, 14, 15], # query [16, 17, 18, 19], # key - [20, 21, 22, 23], # key - [24, 25, 26, 27], # key - [28, 29, 30, 31], # key [32, 33, 34, 35], # value + [4, 5, 6, 7], # query + [20, 21, 22, 23], # key [36, 37, 38, 39], # value + [8, 9, 10, 11], # query + [24, 25, 26, 27], # key [40, 41, 42, 43], # value + [12, 13, 14, 15], # query + [28, 29, 30, 31], # key [44, 45, 46, 47], # value ] ), @@ -629,30 +633,30 @@ def test_qkv_split(): # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) - qkv_interleaved = torch.tensor( + qkv = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query - [16, 17, 18, 19], # key - [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query + [16, 17, 18, 19], # key [20, 21, 22, 23], # key + [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ) - qkv = torch.cat(qkv_split(qkv_interleaved, config)) + qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( - qkv, + qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query + [16, 17, 18, 19], # key + [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query - [16, 17, 18, 19], # key [20, 21, 22, 23], # key - [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ), @@ -660,7 +664,7 @@ def test_qkv_split(): # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) - qkv_interleaved = torch.tensor( + qkv = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query @@ -670,9 +674,9 @@ def test_qkv_split(): [20, 21, 22, 23], # value ] ) - qkv = torch.cat(qkv_split(qkv_interleaved, config)) + qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( - qkv, + qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query diff --git a/tests/test_generate_sequentially.py b/tests/test_generate_sequentially.py index 51bc9d2fe1..2d7603eb60 100644 --- a/tests/test_generate_sequentially.py +++ b/tests/test_generate_sequentially.py @@ -12,13 +12,13 @@ import pytest import torch import yaml -from tests.conftest import RunIf from lightning import Fabric from litgpt import Config from litgpt.generate.sequentially import layer_to_device, replace_device, sequential from litgpt.model import GPT, Block from litgpt.scripts.download import download_from_hub +from tests.conftest import RunIf @pytest.mark.parametrize( @@ -117,8 +117,8 @@ def _test_model_1device(accelerator): "cos": device_str, "sin": device_str, "lm_head.weight": device_str, - "transformer.h.0.attn.attn.bias": device_str, - "transformer.h.0.attn.attn.weight": device_str, + "transformer.h.0.attn.qkv.bias": device_str, + "transformer.h.0.attn.qkv.weight": device_str, "transformer.h.0.attn.proj.bias": device_str, "transformer.h.0.attn.proj.weight": device_str, "transformer.h.0.mlp.fc.bias": device_str, @@ -131,8 +131,8 @@ def _test_model_1device(accelerator): "transformer.h.0.norm_2.weight": device_str, "transformer.h.0.attn.kv_cache.k": device_str, "transformer.h.0.attn.kv_cache.v": device_str, - "transformer.h.1.attn.attn.bias": device_str, - "transformer.h.1.attn.attn.weight": device_str, + "transformer.h.1.attn.qkv.bias": device_str, + "transformer.h.1.attn.qkv.weight": device_str, "transformer.h.1.attn.proj.bias": device_str, "transformer.h.1.attn.proj.weight": device_str, "transformer.h.1.mlp.fc.bias": device_str, @@ -187,8 +187,8 @@ def test_model_forward_hooks(): "transformer.wte.weight": "cuda:0", "transformer.h.0.norm_1.weight": "cuda:0", "transformer.h.0.norm_1.bias": "cuda:0", - "transformer.h.0.attn.attn.weight": "cuda:0", - "transformer.h.0.attn.attn.bias": "cuda:0", + "transformer.h.0.attn.qkv.weight": "cuda:0", + "transformer.h.0.attn.qkv.bias": "cuda:0", "transformer.h.0.attn.proj.weight": "cuda:0", "transformer.h.0.attn.proj.bias": "cuda:0", "transformer.h.0.norm_2.weight": "cuda:0", @@ -199,8 +199,8 @@ def test_model_forward_hooks(): "transformer.h.0.mlp.proj.bias": "cuda:0", "transformer.h.1.norm_1.weight": "cuda:0", "transformer.h.1.norm_1.bias": "cuda:0", - "transformer.h.1.attn.attn.weight": "cuda:0", - "transformer.h.1.attn.attn.bias": "cuda:0", + "transformer.h.1.attn.qkv.weight": "cuda:0", + "transformer.h.1.attn.qkv.bias": "cuda:0", "transformer.h.1.attn.proj.weight": "cuda:0", "transformer.h.1.attn.proj.bias": "cuda:0", "transformer.h.1.norm_2.weight": "cuda:0", @@ -211,8 +211,8 @@ def test_model_forward_hooks(): "transformer.h.1.mlp.proj.bias": "cuda:0", "transformer.h.2.norm_1.weight": "cuda:0", "transformer.h.2.norm_1.bias": "cuda:0", - "transformer.h.2.attn.attn.weight": "cuda:0", - "transformer.h.2.attn.attn.bias": "cuda:0", + "transformer.h.2.attn.qkv.weight": "cuda:0", + "transformer.h.2.attn.qkv.bias": "cuda:0", "transformer.h.2.attn.proj.weight": "cuda:0", "transformer.h.2.attn.proj.bias": "cuda:0", "transformer.h.2.norm_2.weight": "cuda:0", @@ -223,8 +223,8 @@ def test_model_forward_hooks(): "transformer.h.2.mlp.proj.bias": "cuda:0", "transformer.h.3.norm_1.weight": "cuda:1", "transformer.h.3.norm_1.bias": "cuda:1", - "transformer.h.3.attn.attn.weight": "cuda:1", - "transformer.h.3.attn.attn.bias": "cuda:1", + "transformer.h.3.attn.qkv.weight": "cuda:1", + "transformer.h.3.attn.qkv.bias": "cuda:1", "transformer.h.3.attn.proj.weight": "cuda:1", "transformer.h.3.attn.proj.bias": "cuda:1", "transformer.h.3.norm_2.weight": "cuda:1", @@ -235,8 +235,8 @@ def test_model_forward_hooks(): "transformer.h.3.mlp.proj.bias": "cuda:1", "transformer.h.4.norm_1.weight": "cuda:1", "transformer.h.4.norm_1.bias": "cuda:1", - "transformer.h.4.attn.attn.weight": "cuda:1", - "transformer.h.4.attn.attn.bias": "cuda:1", + "transformer.h.4.attn.qkv.weight": "cuda:1", + "transformer.h.4.attn.qkv.bias": "cuda:1", "transformer.h.4.attn.proj.weight": "cuda:1", "transformer.h.4.attn.proj.bias": "cuda:1", "transformer.h.4.norm_2.weight": "cuda:1", @@ -247,8 +247,8 @@ def test_model_forward_hooks(): "transformer.h.4.mlp.proj.bias": "cuda:1", "transformer.h.5.norm_1.weight": "cuda:1", "transformer.h.5.norm_1.bias": "cuda:1", - "transformer.h.5.attn.attn.weight": "cuda:1", - "transformer.h.5.attn.attn.bias": "cuda:1", + "transformer.h.5.attn.qkv.weight": "cuda:1", + "transformer.h.5.attn.qkv.bias": "cuda:1", "transformer.h.5.attn.proj.weight": "cuda:1", "transformer.h.5.attn.proj.bias": "cuda:1", "transformer.h.5.norm_2.weight": "cuda:1", diff --git a/tests/test_lora.py b/tests/test_lora.py index 079d900d0b..c417d588a4 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout +from copy import deepcopy from io import StringIO from itertools import product from unittest import mock @@ -23,10 +24,19 @@ from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.lora import GPT as LoRAGPT +from litgpt.lora import ( + CausalSelfAttention, + Config, + LoRALinear, + LoRAQKVLinear, + lora_filter, + mark_only_lora_as_trainable, + merge_lora_weights, +) from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention -from litgpt.lora import Config, LoRALinear, LoRAQKVLinear, lora_filter, mark_only_lora_as_trainable, merge_lora_weights from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -100,11 +110,11 @@ def test_lora_mqa_gqa(): ) assert config.n_query_groups == config.n_head model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) - lora_ind = [0, 1, 6, 7, 12, 13, 18, 19, 4, 5, 10, 11, 16, 17, 22, 23] + lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] assert attn.linear.weight.shape == (24, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (16, 2) @@ -121,7 +131,7 @@ def test_lora_mqa_gqa(): # MQA config.n_query_groups = 1 model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) @@ -142,11 +152,11 @@ def test_lora_mqa_gqa(): # GQA config.n_query_groups = 2 model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) - lora_ind = [0, 1, 2, 3, 8, 9, 10, 11, 6, 7, 14, 15] + lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] assert attn.linear.weight.shape == (16, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (12, 2) @@ -169,12 +179,12 @@ def test_lora_filter(tmp_path): saved = torch.load(save_path)["model"] expected = { - "transformer.h.1.attn.attn.lora_B", - "transformer.h.2.attn.attn.lora_B", - "transformer.h.2.attn.attn.lora_A", - "transformer.h.1.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_B", + "transformer.h.2.attn.qkv.lora_B", + "transformer.h.2.attn.qkv.lora_A", + "transformer.h.1.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_B", } assert set(saved) == expected @@ -665,7 +675,7 @@ def test_against_original_gemma_2(model_name): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = LoRAGPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -740,29 +750,29 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.uint8": { - "transformer.h.0.attn.attn.linear.weight", + "transformer.h.0.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", "transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", - "transformer.h.1.attn.attn.linear.weight", + "transformer.h.1.attn.qkv.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { - "transformer.h.0.attn.attn.lora_B", + "transformer.h.0.attn.qkv.lora_B", "transformer.h.0.norm_2.weight", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", "transformer.h.1.mlp.fc.linear.bias", "transformer.ln_f.bias", - "transformer.h.1.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_B", "transformer.h.1.attn.proj.linear.bias", "transformer.h.1.norm_1.weight", - "transformer.h.1.attn.attn.linear.bias", - "transformer.h.1.attn.attn.lora_A", + "transformer.h.1.attn.qkv.linear.bias", + "transformer.h.1.attn.qkv.lora_A", "transformer.h.1.norm_1.bias", "transformer.h.1.norm_2.bias", "transformer.h.0.attn.proj.linear.bias", @@ -771,11 +781,11 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.norm_2.bias", "transformer.ln_f.weight", - "transformer.h.0.attn.attn.lora_A", + "transformer.h.0.attn.qkv.lora_A", "transformer.h.1.norm_2.weight", "transformer.h.1.mlp.proj.linear.bias", "transformer.h.0.norm_1.weight", - "transformer.h.0.attn.attn.linear.bias", + "transformer.h.0.attn.qkv.linear.bias", }, } @@ -787,10 +797,10 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.float16": { - "transformer.h.1.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_B", - "transformer.h.1.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_B", + "transformer.h.1.attn.qkv.lora_B", } } @@ -835,11 +845,12 @@ def test_lora_model_fsdp_init(): def test_zero_pad_cpu_and_mocked_mps(): - in_features = 128 - out_features = 384 head_size = 64 n_head = 12 n_query_groups = 3 + in_features = 128 + kv_embed_dim = in_features // (n_head // n_query_groups) + out_features = in_features + 2 * kv_embed_dim enable_lora = [True, False, True] r = 4 @@ -850,12 +861,12 @@ def test_zero_pad_cpu_and_mocked_mps(): n_head=n_head, n_query_groups=n_query_groups, r=r, - enable_lora=enable_lora + enable_lora=enable_lora, ) batch_size = 64 seq_len = 64 - embed_dim = 320 + embed_dim = 160 x = torch.randn(batch_size, seq_len, embed_dim) result_cpu = model.zero_pad(x) @@ -868,3 +879,29 @@ def test_zero_pad_cpu_and_mocked_mps(): assert result_cpu.shape == result_mps.shape, "Shape mismatch between CPU and MPS" assert torch.allclose(result_cpu, result_mps), "Tensor values mismatch between CPU and MPS" + + + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + lora_r=8, + lora_alpha=16, + lora_dropout=0.1 + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config) + state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_model.py b/tests/test_model.py index 9a21f0d34d..abd1a767bf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,9 +2,9 @@ from copy import deepcopy from functools import partial +from unittest import mock import pytest -from unittest import mock import torch from lightning import Fabric from lightning.fabric.utilities.imports import _IS_WINDOWS @@ -31,8 +31,8 @@ from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM import litgpt.config as config_module -from litgpt.model import batched_index_copy_ from litgpt import GPT, Config +from litgpt.model import CausalSelfAttention, batched_index_copy_ from litgpt.scripts.convert_hf_checkpoint import ( copy_weights_falcon, copy_weights_gemma_2, @@ -41,6 +41,7 @@ copy_weights_phi, copy_weights_qwen_2_5, ) +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -97,7 +98,7 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua state_dict = {} theirs_model = GPTNeoXForCausalLM(theirs_config).to(device) # load the hf initialization into our model - copy_weights_gpt_neox(state_dict, theirs_model.state_dict()) + copy_weights_gpt_neox(ours_config, state_dict, theirs_model.state_dict()) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -152,7 +153,7 @@ def test_against_hf_falcon(kwargs, device, dtype): theirs_model = FalconForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} - copy_weights_falcon(kwargs["name"], state_dict, theirs_state_dict) + copy_weights_falcon(ours_config, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -556,6 +557,7 @@ def test_against_hf_mixtral(model_name): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() @pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf")) @pytest.mark.parametrize( @@ -614,6 +616,7 @@ def test_against_olmo(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), @@ -779,7 +782,7 @@ def test_against_original_gemma_2(model_name, device, dtype): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -1298,3 +1301,24 @@ def test_batched_index_copy_modes(): val_3_mps = val_3 batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps) assert torch.allclose(t3_cpu, t3_mps), "Mismatch with negative dimension on mocked MPS" + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config) + state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) From d2ed13c34a8e3763d69c29961e462b5e4a783e17 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Sat, 28 Dec 2024 21:47:20 +0300 Subject: [PATCH 17/18] Bump PyTorch, PyTorch-Lightning and BnB versions (#1893) --- pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1dd1e53743..77e343417c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,9 +9,9 @@ readme = "README.md" license = { file = "LICENSE" } dependencies = [ - "torch>=2.2.0,<=2.4.1", + "torch>=2.5.0,<2.6.0", "numpy<2.0", - "lightning==2.4.0", + "lightning>=2.5.0,<2.6.0", "jsonargparse[signatures]>=4.30.1,<=4.32.1", # 4.33 does not seem to be compatible with Python 3.9 "huggingface_hub>=0.23.5", # download models "safetensors>=0.4.3", # download models @@ -38,7 +38,8 @@ test = [ "protobuf>=4.23.4", ] all = [ - "bitsandbytes==0.42.0", # quantization + "bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32'", # quantization + "bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin'", # quantization "sentencepiece>=0.2.0", # llama-based models "requests>=2.31.0", # litgpt.data "litdata==0.2.17", # litgpt.data From 93fc1b8f5259c21c6298090712d5e5be7bbdc732 Mon Sep 17 00:00:00 2001 From: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com> Date: Mon, 30 Dec 2024 20:06:34 +0300 Subject: [PATCH 18/18] Pin version of mistune in check links workflow (#1895) --- .github/workflows/check-links.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/check-links.yml b/.github/workflows/check-links.yml index 96e8b860d4..8efc74b738 100644 --- a/.github/workflows/check-links.yml +++ b/.github/workflows/check-links.yml @@ -23,9 +23,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + pip install "mistune<3.1" # a newer version is incompatible with nbconvert pip install pytest pytest-check-links - name: Check links run: | pytest --check-links README.md --check-links-ignore "http*" - pytest --check-links tutorials --check-links-ignore "http*" \ No newline at end of file + pytest --check-links tutorials --check-links-ignore "http*"