From 549be4cc056cc2a60b2a45cba094c1642e3f3ace Mon Sep 17 00:00:00 2001 From: "Jason T. Greene" Date: Sun, 22 Dec 2024 09:25:10 -0600 Subject: [PATCH] [Bugfix] Fix fully sharded LoRAs with Mixtral (#11390) Signed-off-by: Jason Greene --- tests/lora/test_mixtral.py | 4 +++- vllm/lora/layers.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 150221dfce6ab..797a495201d33 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -62,8 +62,9 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): @pytest.mark.parametrize("tp_size", [4]) +@pytest.mark.parametrize("fully_shard", [True, False]) def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules, - tp_size): + tp_size, fully_shard): """This LoRA model has all supported Mixtral target modules""" if torch.cuda.device_count() < tp_size: @@ -82,6 +83,7 @@ def test_mixtral_lora_all_target_modules(mixtral_lora_files_all_target_modules, max_loras=4, distributed_executor_backend="ray", tensor_parallel_size=tp_size, + fully_sharded_loras=fully_shard, max_lora_rank=32, ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a6c93a3d8bfe9..85164c2165a3c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -425,8 +425,9 @@ def forward(self, input_): if self.base_layer.skip_bias_add else None) return output, output_bias + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. @classmethod - @_not_fully_sharded_can_replace def can_replace_layer( cls, source_layer: nn.Module,