From 4a5df028043a805939a3cb0b6d7cf5c2db604c27 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 20 Nov 2024 19:20:21 +0000 Subject: [PATCH 01/19] Added non-triton SGMV and BGMV ops (not kernels yet) Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 + vllm/lora/ops/default/lora_ops.py | 119 ++++++++++++++++++ vllm/lora/ops/{ => triton}/__init__.py | 0 vllm/lora/ops/{ => triton}/bgmv_expand.py | 0 .../ops/{ => triton}/bgmv_expand_slice.py | 0 vllm/lora/ops/{ => triton}/bgmv_shrink.py | 0 vllm/lora/ops/{ => triton}/sgmv_expand.py | 0 .../ops/{ => triton}/sgmv_expand_slice.py | 0 vllm/lora/ops/{ => triton}/sgmv_shrink.py | 0 vllm/lora/ops/{ => triton}/utils.py | 0 vllm/lora/punica.py | 14 ++- 11 files changed, 128 insertions(+), 6 deletions(-) create mode 100644 vllm/lora/ops/default/lora_ops.py rename vllm/lora/ops/{ => triton}/__init__.py (100%) rename vllm/lora/ops/{ => triton}/bgmv_expand.py (100%) rename vllm/lora/ops/{ => triton}/bgmv_expand_slice.py (100%) rename vllm/lora/ops/{ => triton}/bgmv_shrink.py (100%) rename vllm/lora/ops/{ => triton}/sgmv_expand.py (100%) rename vllm/lora/ops/{ => triton}/sgmv_expand_slice.py (100%) rename vllm/lora/ops/{ => triton}/sgmv_shrink.py (100%) rename vllm/lora/ops/{ => triton}/utils.py (100%) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6afe80219fe07..b7bc34d7b490e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1495,6 +1495,7 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits + print("punica", logits.dtype) # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, self.lora_a_stacked, diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py new file mode 100644 index 0000000000000..8ad32dd4a77b7 --- /dev/null +++ b/vllm/lora/ops/default/lora_ops.py @@ -0,0 +1,119 @@ +import torch + +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + add_inputs + ) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + inputs = inputs.to(dtype=torch.float16) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:] += outputs[:] + else: + output_tensor[:] = outputs[:] + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_shrink( + inputs, + lora_a_weights, + output_tensor, + exploded_indices, + scaling + ) + +def bgmv_shrink( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0 +): + selected_loras = 
lora_b_weights[lora_indices_tensor].squeeze() + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:] = scaling * outputs[:] + +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + slice_offset, + slice_size, + add_inputs + ) + + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + inputs = inputs.to(dtype=torch.float16) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset+slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset+slice_size] = outputs[:] \ No newline at end of file diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/triton/__init__.py similarity index 100% rename from vllm/lora/ops/__init__.py rename to vllm/lora/ops/triton/__init__.py diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/triton/bgmv_expand.py similarity index 100% rename from vllm/lora/ops/bgmv_expand.py rename to vllm/lora/ops/triton/bgmv_expand.py diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/triton/bgmv_expand_slice.py similarity index 100% rename from vllm/lora/ops/bgmv_expand_slice.py rename to vllm/lora/ops/triton/bgmv_expand_slice.py diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/triton/bgmv_shrink.py similarity index 100% rename from vllm/lora/ops/bgmv_shrink.py rename to vllm/lora/ops/triton/bgmv_shrink.py diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/triton/sgmv_expand.py similarity index 100% rename from vllm/lora/ops/sgmv_expand.py rename to vllm/lora/ops/triton/sgmv_expand.py diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/triton/sgmv_expand_slice.py similarity index 100% rename from vllm/lora/ops/sgmv_expand_slice.py rename to vllm/lora/ops/triton/sgmv_expand_slice.py diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/triton/sgmv_shrink.py similarity index 100% rename from vllm/lora/ops/sgmv_shrink.py rename to vllm/lora/ops/triton/sgmv_shrink.py diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/triton/utils.py similarity index 100% rename from vllm/lora/ops/utils.py rename to vllm/lora/ops/triton/utils.py diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 082041f390750..7eaaf2237d72e 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -12,12 +12,14 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.sgmv_shrink import sgmv_shrink + from vllm.lora.ops.triton.bgmv_expand import bgmv_expand + from vllm.lora.ops.triton.bgmv_expand_slice import bgmv_expand_slice + from 
vllm.lora.ops.triton.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.triton.sgmv_expand import sgmv_expand + from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink +else: + from vllm.lora.ops.default.lora_ops import * if TYPE_CHECKING: # avoid circuit import From 8c45f11701a923cc23b02c06d0ddb35b517a58df Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 21 Nov 2024 11:40:53 +0000 Subject: [PATCH 02/19] Removed extra print Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index b7bc34d7b490e..6afe80219fe07 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1495,7 +1495,6 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - print("punica", logits.dtype) # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, self.lora_a_stacked, From 676991b0c6b4af600c70d557ebbc085699328eea Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 21 Nov 2024 11:48:27 +0000 Subject: [PATCH 03/19] Minor changes Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 29ecf37808205..d9484fe15e237 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -69,7 +69,7 @@ def dist_init(): rank=0, distributed_init_method=f"file://{temp_file}", local_rank=0, - backend="nccl", + backend="gloo", # TODO: Find a way to easily switch between this and nccl ) initialize_model_parallel(1, 1) yield @@ -82,10 +82,10 @@ def dist_init_torch_only(): return temp_file = tempfile.mkstemp()[1] torch.distributed.init_process_group( - backend="nccl", world_size=1, rank=0, init_method=f"file://{temp_file}", + backend="gloo", # TODO: Find a way to easily switch between this and nccl ) From 3457a73f318fc211972faf8b5500f1356a6dbcd5 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 22 Nov 2024 12:27:02 +0000 Subject: [PATCH 04/19] Made some minor shape-based fixes to the kernels Signed-off-by: Akshat Tripathi --- vllm/lora/ops/default/lora_ops.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py index 8ad32dd4a77b7..cd12f3659f478 100644 --- a/vllm/lora/ops/default/lora_ops.py +++ b/vllm/lora/ops/default/lora_ops.py @@ -30,14 +30,18 @@ def bgmv_expand( lora_indices_tensor: torch.Tensor, add_inputs: bool = True ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() - inputs = inputs.to(dtype=torch.float16) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + if add_inputs: - output_tensor[:] += outputs[:] + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] else: - output_tensor[:] = outputs[:] + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] def sgmv_shrink( inputs: torch.Tensor, @@ -68,10 +72,10 @@ def bgmv_shrink( lora_indices_tensor: torch.Tensor, scaling: float = 1.0 ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) outputs = 
torch.einsum("bi, boi -> bo", inputs, selected_loras) - output_tensor[:] = scaling * outputs[:] + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] def sgmv_expand_slice( inputs: torch.Tensor, @@ -109,8 +113,8 @@ def bgmv_expand_slice( slice_size: int, add_inputs: bool = True ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() - inputs = inputs.to(dtype=torch.float16) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: From 898e4e486b0b92b33226b68f5c3b91ec0eb124ef Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 3 Dec 2024 12:21:27 +0000 Subject: [PATCH 05/19] Added multi lora execution to the CPU backend Signed-off-by: Akshat Tripathi --- vllm/executor/cpu_executor.py | 3 - vllm/worker/cpu_model_runner.py | 103 +++++++++++++++++++++++++++++++- vllm/worker/cpu_worker.py | 19 +++++- 3 files changed, 118 insertions(+), 7 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 4ceb5a837dd7f..c0cd473954ac3 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -26,9 +26,6 @@ class CPUExecutor(ExecutorBase): def _init_executor(self) -> None: assert self.device_config.device_type == "cpu" - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - assert self.lora_config is None, "cpu backend doesn't support LoRA" # # Environment variables for CPU executor diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d3e1202c15e61..08b447bebb05a 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -3,7 +3,7 @@ from collections import defaultdict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar, Union) + TypeVar, Set, Union) import torch from torch import nn @@ -27,6 +27,13 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +from vllm.model_executor.models import supports_lora + +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager + + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -48,6 +55,8 @@ class ModelInputForCPU(ModelRunnerInputBase): virtual_engine: Optional[int] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -55,6 +64,8 @@ def as_broadcastable_tensor_dict( "input_tokens": self.input_tokens, "input_positions": self.input_positions, "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -116,6 +127,7 @@ def __init__(self, self.block_size = self.runner.block_size self.device = self.runner.device self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.enable_lora = self.runner.lora_config is not None def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) @@ -136,11 +148,21 @@ def build(self) -> ModelInputForCPU: self.seq_group_metadata_list) seq_lens = None + # LoRA data. 
+ lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(seq.lora_request for seq in self.seq_group_metadata_list if seq.lora_request is not None) + + lora_mapping = self._prepare_lora_input(self.seq_group_metadata_list, is_prompt) + return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, multi_modal_kwargs=multi_modal_kwargs, + lora_mapping=lora_mapping, + lora_requests=lora_requests, # query_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill # just use seq_lens instead. @@ -419,6 +441,23 @@ def _prepare_decode( attn_metadata, ) + def _prepare_lora_input(self, + seq_group_metadata_list: List[SequenceGroupMetadata], is_prefill: bool) -> LoRAMapping: + index_mapping = [] + prompt_mapping = [] + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + query_len = seq.token_chunk_size + + index_mapping += [lora_id] * query_len + prompt_mapping += [lora_id] * (query_len if seq.sampling_params and seq.sampling_params.prompt_logprobs is not None else 1) + + return LoRAMapping( + index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=is_prefill + ) + class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): """ @@ -464,9 +503,34 @@ def __init__( # Lazy initialization. self.model: nn.Module # Set after init_Model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) + + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + max_pos_embeddings = self.model.config.max_position_embeddings # TODO: Add the multimodal case + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() def _prepare_model_input_tensors( self, @@ -483,6 +547,37 @@ def _prepare_model_input_tensors( builder.add_seq_group(seq_group_metadata) return builder.build() # type: ignore + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() class 
CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): @@ -535,6 +630,12 @@ def execute_model( if num_steps > 1: raise ValueError( "CPU worker does not support multi-step execution.") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) model_executable = self.model execute_model_kwargs = { diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index bc9164bd9d5df..8832d18ea74ff 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type, Set import torch import torch.distributed @@ -11,6 +11,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -18,7 +19,7 @@ from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerBase, + WorkerBase, WorkerInput) logger = init_logger(__name__) @@ -111,7 +112,7 @@ def get_cache_block_size( return dtype_size * total -class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): +class CPUWorker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a CPU socket. Each worker is associated with a single CPU socket. The worker is @@ -250,6 +251,18 @@ def initialize_cache(self, num_gpu_blocks: int, # Initialize the cache. self._init_cache_engine() + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: """Raise errors if the num_cpu_blocks is invalid. 
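
For reference, the fallback ops added in PATCH 01 and reshaped in PATCH 04 reduce BGMV to a per-token gather of LoRA-B matrices followed by a single einsum contraction. The sketch below is illustrative only: the shapes, names, and the add_inputs=True case are made up for the example, and the stacked weights inside vLLM carry an extra unit dimension that the real op squeezes away before the contraction.

```python
# Minimal standalone sketch of the einsum-based bgmv_expand fallback.
# All shapes are hypothetical: 4 tokens, 2 LoRA adapters, rank 8, output dim 32.
import torch

num_tokens, num_loras, rank, out_dim = 4, 2, 8, 32
inputs = torch.randn(num_tokens, rank)               # shrink-side activations, one row per token
lora_b_weights = torch.randn(num_loras, out_dim, rank)  # one B matrix per adapter
lora_indices = torch.tensor([0, 0, 1, 1])            # adapter id assigned to each token
output_tensor = torch.zeros(num_tokens, out_dim)

# Gather each token's B matrix, then contract: (b, i) x (b, o, i) -> (b, o)
selected = lora_b_weights[lora_indices]
outputs = torch.einsum("bi, boi -> bo", inputs, selected)
output_tensor[:, :outputs.shape[1]] += outputs       # add_inputs=True behaviour

# Per-token matmul as a correctness reference for the contraction above
reference = torch.stack([lora_b_weights[i] @ x for i, x in zip(lora_indices, inputs)])
assert torch.allclose(output_tensor, reference, atol=1e-5)
```

The sgmv_* variants in the same file reuse this path: they expand the per-sequence lora_indices_tensor with torch.repeat_interleave over seq_len_tensor so that every token of a sequence picks up that sequence's adapter, then delegate to the bgmv_* op.
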
From c537cc9ac9fc72d28f3a6fa33a539ae6034ba919 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 3 Dec 2024 15:28:54 +0000 Subject: [PATCH 06/19] Added __init__.py files Signed-off-by: Akshat Tripathi --- vllm/lora/ops/__init__.py | 0 vllm/lora/ops/default/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/lora/ops/__init__.py create mode 100644 vllm/lora/ops/default/__init__.py diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/lora/ops/default/__init__.py b/vllm/lora/ops/default/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From 8a17bc58dde1ba36fa0ceb1eda76a1f293b20a5f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 3 Dec 2024 16:09:26 +0000 Subject: [PATCH 07/19] Changed lora op importing to work irrespective of backend Signed-off-by: Akshat Tripathi --- tests/lora/test_punica_sizes.py | 14 ++++++++------ tests/lora/test_punica_variation.py | 14 ++++++++------ vllm/lora/ops/__init__.py | 18 ++++++++++++++++++ vllm/lora/punica.py | 19 ++++++++----------- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 66b5f82bbb97d..7750da94b46ce 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -7,12 +7,14 @@ import pytest import torch -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice -from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink +) from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 52b82f25d23e1..003088fd08eab 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -6,12 +6,14 @@ import pytest import torch -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice -from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink +) from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py index e69de29bb2d1d..ebde15779343d 100644 --- a/vllm/lora/ops/__init__.py +++ b/vllm/lora/ops/__init__.py @@ -0,0 +1,18 @@ +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.triton.bgmv_expand import bgmv_expand + from vllm.lora.ops.triton.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.triton.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.triton.sgmv_expand import sgmv_expand + from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink +else: # TODO: Replace wiht HAS_XLA + from vllm.lora.ops.default.lora_ops import ( + 
bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink + ) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 7eaaf2237d72e..a1032b806ab4e 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -9,17 +9,14 @@ import torch -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.lora.ops.triton.bgmv_expand import bgmv_expand - from vllm.lora.ops.triton.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.triton.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.triton.sgmv_expand import sgmv_expand - from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink -else: - from vllm.lora.ops.default.lora_ops import * +from vllm.lora.ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink +) if TYPE_CHECKING: # avoid circuit import From 733cecc7b09bf086551504ccf9e3ca9312287ff0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 3 Dec 2024 16:10:00 +0000 Subject: [PATCH 08/19] Removed comment Signed-off-by: Akshat Tripathi --- vllm/lora/ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py index ebde15779343d..aecb1f25ba6de 100644 --- a/vllm/lora/ops/__init__.py +++ b/vllm/lora/ops/__init__.py @@ -7,7 +7,7 @@ from vllm.lora.ops.triton.sgmv_expand import sgmv_expand from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink -else: # TODO: Replace wiht HAS_XLA +else: from vllm.lora.ops.default.lora_ops import ( bgmv_expand, bgmv_expand_slice, From 5b8d04a35260383b3f6633a277622085bf3a234e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 3 Dec 2024 16:15:24 +0000 Subject: [PATCH 09/19] Added multimodal lora support Signed-off-by: Akshat Tripathi --- vllm/worker/cpu_model_runner.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 08b447bebb05a..4ebacd747b251 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -27,13 +27,12 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) -from vllm.model_executor.models import supports_lora +from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager - if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -508,14 +507,24 @@ def __init__( def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) - + if self.lora_config: assert supports_lora( self.model ), f"{self.model.__class__.__name__} does not support LoRA yet." - max_pos_embeddings = self.model.config.max_position_embeddings # TODO: Add the multimodal case - + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. 
+ if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, From 7ad9748ed967382b060a769ea5eb062925e70ffc Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 3 Dec 2024 16:15:37 +0000 Subject: [PATCH 10/19] Ran formatter Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 6 +- tests/lora/test_punica_sizes.py | 10 +- tests/lora/test_punica_variation.py | 10 +- vllm/lora/ops/__init__.py | 11 +- vllm/lora/ops/default/lora_ops.py | 162 ++++++++++++---------------- vllm/lora/punica.py | 10 +- vllm/worker/cpu_model_runner.py | 32 +++--- vllm/worker/cpu_worker.py | 5 +- 8 files changed, 104 insertions(+), 142 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index d9484fe15e237..ed86bdb4087f4 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -69,7 +69,8 @@ def dist_init(): rank=0, distributed_init_method=f"file://{temp_file}", local_rank=0, - backend="gloo", # TODO: Find a way to easily switch between this and nccl + backend= + "gloo", # TODO: Find a way to easily switch between this and nccl ) initialize_model_parallel(1, 1) yield @@ -85,7 +86,8 @@ def dist_init_torch_only(): world_size=1, rank=0, init_method=f"file://{temp_file}", - backend="gloo", # TODO: Find a way to easily switch between this and nccl + backend= + "gloo", # TODO: Find a way to easily switch between this and nccl ) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 7750da94b46ce..7b569077453eb 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -7,14 +7,8 @@ import pytest import torch -from vllm.lora.ops import ( - bgmv_expand, - bgmv_expand_slice, - bgmv_shrink, - sgmv_expand, - sgmv_expand_slice, - sgmv_shrink -) +from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 003088fd08eab..21dcec4456f5a 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -6,14 +6,8 @@ import pytest import torch -from vllm.lora.ops import ( - bgmv_expand, - bgmv_expand_slice, - bgmv_shrink, - sgmv_expand, - sgmv_expand_slice, - sgmv_shrink -) +from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py index aecb1f25ba6de..5447c82b534ba 100644 --- a/vllm/lora/ops/__init__.py +++ b/vllm/lora/ops/__init__.py @@ -8,11 +8,6 @@ from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink else: - from vllm.lora.ops.default.lora_ops import ( - bgmv_expand, - bgmv_expand_slice, - bgmv_shrink, - sgmv_expand, - sgmv_expand_slice, - sgmv_shrink - ) + from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) diff --git a/vllm/lora/ops/default/lora_ops.py 
b/vllm/lora/ops/default/lora_ops.py index cd12f3659f478..205977a81c249 100644 --- a/vllm/lora/ops/default/lora_ops.py +++ b/vllm/lora/ops/default/lora_ops.py @@ -1,48 +1,42 @@ import torch -def sgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - add_inputs: bool = False -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_expand( - inputs, - lora_b_weights, - output_tensor, - exploded_indices, - add_inputs - ) - - -def bgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True -): + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - + limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1 - + if add_inputs: output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] else: output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] - + + def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, @@ -55,69 +49,55 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_shrink( - inputs, - lora_a_weights, - output_tensor, - exploded_indices, - scaling - ) - -def bgmv_shrink( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0 -): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - -def sgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_expand_slice( - inputs, - lora_b_weights, - output_tensor, - exploded_indices, - slice_offset, - slice_size, - add_inputs - ) - - -def bgmv_expand_slice( - inputs: torch.Tensor, - 
lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True -): + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - + if add_inputs: - output_tensor[:, slice_offset:slice_offset+slice_size] += outputs[:] + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] else: - output_tensor[:, slice_offset:slice_offset+slice_size] = outputs[:] \ No newline at end of file + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index a1032b806ab4e..1cf87aeca3022 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -9,14 +9,8 @@ import torch -from vllm.lora.ops import ( - bgmv_expand, - bgmv_expand_slice, - bgmv_shrink, - sgmv_expand, - sgmv_expand_slice, - sgmv_shrink -) +from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) if TYPE_CHECKING: # avoid circuit import diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 4ebacd747b251..925e6e0d68737 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -151,9 +151,12 @@ def build(self) -> ModelInputForCPU: lora_requests = set() lora_mapping = None if self.enable_lora: - lora_requests = set(seq.lora_request for seq in self.seq_group_metadata_list if seq.lora_request is not None) + lora_requests = set(seq.lora_request + for seq in self.seq_group_metadata_list + if seq.lora_request is not None) - lora_mapping = self._prepare_lora_input(self.seq_group_metadata_list, is_prompt) + lora_mapping = self._prepare_lora_input( + self.seq_group_metadata_list, is_prompt) return self.model_input_cls( input_tokens=input_tokens, @@ -440,22 +443,23 @@ def _prepare_decode( attn_metadata, ) - def _prepare_lora_input(self, - seq_group_metadata_list: List[SequenceGroupMetadata], is_prefill: bool) -> LoRAMapping: + def _prepare_lora_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + is_prefill: bool) -> LoRAMapping: index_mapping = [] prompt_mapping = [] for seq in seq_group_metadata_list: lora_id = seq.lora_int_id query_len = seq.token_chunk_size - + index_mapping += [lora_id] * query_len - prompt_mapping += [lora_id] * (query_len if seq.sampling_params and seq.sampling_params.prompt_logprobs is not None else 1) - - return LoRAMapping( - index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=is_prefill - ) + prompt_mapping += [lora_id] * ( + query_len if seq.sampling_params + and seq.sampling_params.prompt_logprobs is not None 
else 1) + + return LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=is_prefill) class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): @@ -556,7 +560,7 @@ def _prepare_model_input_tensors( builder.add_seq_group(seq_group_metadata) return builder.build() # type: ignore - + def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -639,7 +643,7 @@ def execute_model( if num_steps > 1: raise ValueError( "CPU worker does not support multi-step execution.") - + if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 8832d18ea74ff..fe8d25a79d4e4 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -18,8 +18,7 @@ from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - WorkerBase, +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) @@ -251,7 +250,7 @@ def initialize_cache(self, num_gpu_blocks: int, # Initialize the cache. self._init_cache_engine() - + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) From d639b1a62f8ee1d66028416dc375a3c039fadf0e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 3 Dec 2024 16:17:20 +0000 Subject: [PATCH 11/19] Changed test backend back to nccl Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index ed86bdb4087f4..3d48dce7abcf3 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -69,8 +69,7 @@ def dist_init(): rank=0, distributed_init_method=f"file://{temp_file}", local_rank=0, - backend= - "gloo", # TODO: Find a way to easily switch between this and nccl + backend="nccl" ) initialize_model_parallel(1, 1) yield @@ -86,8 +85,7 @@ def dist_init_torch_only(): world_size=1, rank=0, init_method=f"file://{temp_file}", - backend= - "gloo", # TODO: Find a way to easily switch between this and nccl + backend="nccl" ) From 56c07d7e3fa0dcefdd7a6511232f6330f29f394d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 4 Dec 2024 11:45:28 +0000 Subject: [PATCH 12/19] Updated lora tests to work with cpu Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 16 +++++++++++++--- tests/lora/test_layers.py | 23 +++++++++++++++++------ 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 3d48dce7abcf3..a5dae3a9bf88a 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -10,6 +10,7 @@ import vllm from vllm.config import LoRAConfig +from vllm.platforms import current_platform from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) @@ -64,12 +65,17 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): @pytest.fixture def dist_init(): temp_file = tempfile.mkstemp()[1] + + backend = "nccl" + if current_platform.is_cpu(): + backend = "gloo" + init_distributed_environment( world_size=1, rank=0, distributed_init_method=f"file://{temp_file}", local_rank=0, - backend="nccl" + 
backend=backend ) initialize_model_parallel(1, 1) yield @@ -79,13 +85,17 @@ def dist_init(): @pytest.fixture def dist_init_torch_only(): if torch.distributed.is_initialized(): - return + return + backend = "nccl" + if current_platform.is_cpu(): + backend = "gloo" + temp_file = tempfile.mkstemp()[1] torch.distributed.init_process_group( world_size=1, rank=0, init_method=f"file://{temp_file}", - backend="nccl" + backend=backend ) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 15e576cb065c7..c8097865cb4f7 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from vllm.config import LoRAConfig +from vllm.platforms import current_platform from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, @@ -48,10 +49,20 @@ torch.float32: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } + +pytestmark = pytest.mark.skipif(not ( + current_platform.is_cuda_alike() or + current_platform.is_cpu() +)) + CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +CPU_DEVICES = ["cpu"] + +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES + # We will launch different triton kernels between the prefill and decode # stages, so we need to verify this. prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -194,7 +205,7 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: @@ -296,7 +307,7 @@ def create_random_embedding_layer(): # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings_with_new_embeddings(dist_init, num_loras, device, @@ -432,7 +443,7 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, @@ -563,7 +574,7 @@ def _pretest(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) def test_linear_replicated(dist_init, num_loras, device, stage) -> None: @@ -667,7 +678,7 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, device, stage) -> None: @@ -782,7 +793,7 @@ def 
create_random_linear_parallel_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, device, stage) -> None: From 4059ee71617036a9cbb171e5cea6037f4220575f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 4 Dec 2024 17:25:47 +0000 Subject: [PATCH 13/19] Fixed lora op dtypes Signed-off-by: Akshat Tripathi --- vllm/lora/ops/default/lora_ops.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py index 205977a81c249..fafa9a1f932bf 100644 --- a/vllm/lora/ops/default/lora_ops.py +++ b/vllm/lora/ops/default/lora_ops.py @@ -23,7 +23,7 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1).to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) @@ -61,7 +61,8 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1).to(dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -93,7 +94,7 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1).to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) From 94730429fa5e619bafaeba80d53cdf36f69ab6ed Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 4 Dec 2024 17:26:03 +0000 Subject: [PATCH 14/19] Updated tests to work with the CPU backend Signed-off-by: Akshat Tripathi --- tests/lora/test_layers.py | 26 +++++++++++++++++++------- tests/lora/test_lora_manager.py | 20 ++++++++++++-------- tests/lora/test_punica_sizes.py | 11 +++++++---- tests/lora/test_punica_variation.py | 11 +++++++---- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index c8097865cb4f7..230c1e02df3c8 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -53,7 +53,7 @@ pytestmark = pytest.mark.skipif(not ( current_platform.is_cuda_alike() or current_platform.is_cpu() -)) +), reason="Backend not supported") CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -212,7 +212,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA # device, see: https://github.com/triton-lang/triton/issues/2925 # Same below. 
- torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 @@ -313,7 +314,9 @@ def create_random_embedding_layer(): def test_embeddings_with_new_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) max_loras = 8 punica_wrapper = PunicaWrapper(8192, 256, device) @@ -449,7 +452,9 @@ def create_random_embedding_layer(): def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) max_loras = 8 punica_wrapper = PunicaWrapper(8192, 256, device) @@ -578,7 +583,9 @@ def _pretest(): @pytest.mark.parametrize("stage", STAGES) def test_linear_replicated(dist_init, num_loras, device, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 @@ -683,7 +690,9 @@ def create_random_linear_replicated_layer(): def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, device, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 @@ -798,7 +807,9 @@ def create_random_linear_parallel_layer(): def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, device, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 @@ -944,6 +955,7 @@ class FakeConfig: @pytest.mark.parametrize("rotary_dim", [None, 32]) @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only CUDA backends are supported") def test_rotary_embedding_long_context(dist_init, num_loras, device, scaling_factors, max_position, is_neox_style, rotary_dim, head_size, diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 8d109b2c81503..6b4bea32886bc 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -17,6 +17,7 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.platforms import current_platform EMBEDDING_MODULES = { "embed_tokens": "input_embeddings", @@ -28,9 +29,12 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +CPU_DEVICES = ["cpu"] +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES -@pytest.mark.parametrize("device", CUDA_DEVICES) + +@pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( os.path.join(sql_lora_files, "adapter_model.safetensors")) @@ -113,7 +117,7 @@ def test_replace_submodules(dist_init, dummy_model): manager = LoRAModelManager( model, 1, 1, 1, LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), - torch.device("cuda")) + torch.device(DEVICES[0])) model = manager.model assert 
isinstance(model.get_submodule("dense1"), @@ -125,7 +129,7 @@ def test_replace_submodules(dist_init, dummy_model): RowParallelLinearWithLoRA) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] @@ -186,7 +190,7 @@ def test_lora_model_manager(dist_init, dummy_model, device): assert manager.punica_wrapper.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] @@ -278,7 +282,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): assert manager.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager @@ -408,7 +412,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): assert manager.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) @@ -487,7 +491,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, device) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): # Should remove every LoRA not specified in the request. 
@@ -563,7 +567,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, device) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_packed_loras(dist_init, dummy_model_gate_up, device): model = dummy_model_gate_up model.supported_lora_modules = ["gate_up_proj"] diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 7b569077453eb..af31748f8cbd6 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -106,7 +106,10 @@ MAX_RANKS = [32] SCALES = [0.5] SEED = [0] -CUDA_DEVICES = [f"cuda:{0}"] + +CUDA_DEVICES = ["cuda:0"] +CPU_DEVICES = ["cpu"] +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES def assert_close(a, b): @@ -126,7 +129,7 @@ def assert_close(a, b): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( batches: int, num_loras: int, @@ -216,7 +219,7 @@ def test_punica_sgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv( batches: int, num_loras: int, @@ -290,7 +293,7 @@ def test_punica_bgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_expand_nslices( batches: int, num_loras: int, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 21dcec4456f5a..d04e20e883575 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -21,7 +21,10 @@ MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] SCALES = [0.5] SEED = [0] -CUDA_DEVICES = [f"cuda:{0}"] + +CUDA_DEVICES = ["cuda:0"] +CPU_DEVICES = ["cpu"] +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES def assert_close(a, b): @@ -41,7 +44,7 @@ def assert_close(a, b): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( batches: int, num_loras: int, @@ -131,7 +134,7 @@ def test_punica_sgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv( batches: int, num_loras: int, @@ -207,7 +210,7 @@ def test_punica_bgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_expand_nslices( batches: int, num_loras: int, From e00b6bd6d42ecbb29c420e533339bdd14f8da28f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 5 Dec 2024 15:50:06 +0000 Subject: [PATCH 15/19] Removed redundant .squeeze call in lora ops Signed-off-by: Akshat Tripathi --- vllm/lora/ops/default/lora_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py index fafa9a1f932bf..3dbc16843a006 100644 --- a/vllm/lora/ops/default/lora_ops.py +++ b/vllm/lora/ops/default/lora_ops.py @@ -23,7 +23,7 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1).to(dtype=output_tensor.dtype) + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) @@ -61,7 +61,7 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1).to(dtype=output_tensor.dtype) + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) @@ -94,7 +94,7 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1).to(dtype=output_tensor.dtype) + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) From 22fc007504d1ee33228df0cac6b628e83339f95c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 5 Dec 2024 17:45:52 +0000 Subject: [PATCH 16/19] Stopped skipping some e2e tests Signed-off-by: Akshat Tripathi --- tests/lora/test_layers.py | 1 - tests/lora/test_llama.py | 4 +++- tests/lora/test_mixtral.py | 4 +++- tests/lora/test_quant_model.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 230c1e02df3c8..21ea96a3f27a4 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -9,7 +9,6 @@ import torch.nn.functional as F from vllm.config import LoRAConfig -from vllm.platforms import current_platform from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index e2a4f1ed0496a..6266a1f767f6a 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -6,6 +6,7 @@ import vllm from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform MODEL_PATH = "meta-llama/Llama-2-7b-hf" @@ -40,7 +41,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: @pytest.mark.parametrize("tp_size", [1, 2, 4]) def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): if num_gpus_available < tp_size: - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + if tp_size > 1 and current_platform.is_cuda_alike(): + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM(MODEL_PATH, enable_lora=True, diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index dddc299da446b..d048ca5f06948 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -5,6 +5,7 @@ import vllm from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" @@ 
-32,7 +33,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, def test_mixtral_lora(mixtral_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" if torch.cuda.device_count() < tp_size: - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + if tp_size > 1 and current_platform.is_cuda_alike(): + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") prompts = [ "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]", # noqa: E501 diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 5432fa4ad0d3a..60c1607e27d0c 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -73,7 +73,8 @@ def format_prompt_tuples(prompt): def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, tp_size): if num_gpus_available < tp_size: - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + if tp_size > 1 and current_platform.is_cuda_alike(): + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM( model=model.model_path, From 1272784231fe53e77fb66ce9f554a7c3d2fc283c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 6 Dec 2024 12:20:35 +0000 Subject: [PATCH 17/19] Fixed case where the lora_b tensor has 4 dims not 3 Signed-off-by: Akshat Tripathi --- vllm/lora/ops/default/lora_ops.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py index 3dbc16843a006..2b934d199cdd1 100644 --- a/vllm/lora/ops/default/lora_ops.py +++ b/vllm/lora/ops/default/lora_ops.py @@ -24,6 +24,8 @@ def bgmv_expand(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) @@ -62,6 +64,8 @@ def bgmv_shrink(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) @@ -96,6 +100,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, add_inputs: bool = True): selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = 
selected_loras.squeeze(dim=1) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: From fb5e02f6616c3bbcee3007755212b65a6239ca6c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 9 Dec 2024 13:01:50 +0000 Subject: [PATCH 18/19] Formatted code Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 32 ++++++++++++++----------------- tests/lora/test_layers.py | 10 +++++----- tests/lora/test_llama.py | 6 +++--- tests/lora/test_mixtral.py | 6 +++--- tests/lora/test_quant_model.py | 6 +++--- vllm/lora/ops/__init__.py | 5 +++++ vllm/lora/ops/default/lora_ops.py | 9 ++++++--- vllm/worker/cpu_model_runner.py | 16 +++++++--------- vllm/worker/cpu_worker.py | 2 +- 9 files changed, 47 insertions(+), 45 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index a5dae3a9bf88a..df04c3e0390cb 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -10,7 +10,6 @@ import vllm from vllm.config import LoRAConfig -from vllm.platforms import current_platform from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) @@ -21,6 +20,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +from vllm.platforms import current_platform class ContextIDInfo(TypedDict): @@ -65,18 +65,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): @pytest.fixture def dist_init(): temp_file = tempfile.mkstemp()[1] - + backend = "nccl" if current_platform.is_cpu(): backend = "gloo" - - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend=backend - ) + + init_distributed_environment(world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend=backend) initialize_model_parallel(1, 1) yield cleanup_dist_env_and_memory(shutdown_ray=True) @@ -85,18 +83,16 @@ def dist_init(): @pytest.fixture def dist_init_torch_only(): if torch.distributed.is_initialized(): - return + return backend = "nccl" if current_platform.is_cpu(): backend = "gloo" - + temp_file = tempfile.mkstemp()[1] - torch.distributed.init_process_group( - world_size=1, - rank=0, - init_method=f"file://{temp_file}", - backend=backend - ) + torch.distributed.init_process_group(world_size=1, + rank=0, + init_method=f"file://{temp_file}", + backend=backend) @pytest.fixture diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 21ea96a3f27a4..c5638727abefc 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -49,10 +49,9 @@ torch.bfloat16: (3e-2, 2e-2), } -pytestmark = pytest.mark.skipif(not ( - current_platform.is_cuda_alike() or - current_platform.is_cpu() -), reason="Backend not supported") +pytestmark = pytest.mark.skipif( + not (current_platform.is_cuda_alike() or current_platform.is_cpu()), + reason="Backend not supported") CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -954,7 +953,8 @@ class FakeConfig: @pytest.mark.parametrize("rotary_dim", [None, 32]) @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) -@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Only CUDA backends are supported") +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only CUDA backends are supported") def test_rotary_embedding_long_context(dist_init, 
num_loras, device, scaling_factors, max_position, is_neox_style, rotary_dim, head_size, diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index 6266a1f767f6a..edcf3e019c78c 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -40,9 +40,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: @pytest.mark.parametrize("tp_size", [1, 2, 4]) def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): - if num_gpus_available < tp_size: - if tp_size > 1 and current_platform.is_cuda_alike(): - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + if num_gpus_available < tp_size and \ + tp_size > 1 and current_platform.is_cuda_alike(): + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM(MODEL_PATH, enable_lora=True, diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index d048ca5f06948..31237acd549eb 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -32,9 +32,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, @pytest.mark.parametrize("tp_size", [4]) def test_mixtral_lora(mixtral_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" - if torch.cuda.device_count() < tp_size: - if tp_size > 1 and current_platform.is_cuda_alike(): - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + if torch.cuda.device_count( + ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike(): + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") prompts = [ "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. 
[/user] [assistant]", # noqa: E501 diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 60c1607e27d0c..c2590594a277b 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -72,9 +72,9 @@ def format_prompt_tuples(prompt): @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, tp_size): - if num_gpus_available < tp_size: - if tp_size > 1 and current_platform.is_cuda_alike(): - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + if num_gpus_available < tp_size and \ + tp_size > 1 and current_platform.is_cuda_alike(): + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM( model=model.model_path, diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py index 5447c82b534ba..7a13eabeb6074 100644 --- a/vllm/lora/ops/__init__.py +++ b/vllm/lora/ops/__init__.py @@ -11,3 +11,8 @@ from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) + +__all__ = [ + "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "sgmv_expand", + "sgmv_expand_slice", "sgmv_shrink" +] diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py index 2b934d199cdd1..5f5aafd516159 100644 --- a/vllm/lora/ops/default/lora_ops.py +++ b/vllm/lora/ops/default/lora_ops.py @@ -23,7 +23,8 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) @@ -63,7 +64,8 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) @@ -98,7 +100,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 925e6e0d68737..a3df19596c3fc 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -2,8 +2,8 @@ import weakref from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar, Set, Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import torch from torch import nn @@ -11,10 +11,14 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import 
LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, @@ -27,12 +31,6 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) -from vllm.model_executor.models import supports_lora, supports_multimodal - -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager - if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -516,7 +514,7 @@ def load_model(self) -> None: assert supports_lora( self.model ), f"{self.model.__class__.__name__} does not support LoRA yet." - + if supports_multimodal(self.model): logger.warning("Regarding multimodal models, vLLM currently " "only supports adding LoRA to language model.") diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index fe8d25a79d4e4..1d87bebee7d08 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional, Tuple, Type, Set +from typing import Dict, List, Optional, Set, Tuple, Type import torch import torch.distributed From c3e9afe07f6346c8ca42d95fe18c99495162bd37 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 10 Dec 2024 13:01:49 +0000 Subject: [PATCH 19/19] Fixed formatting issues Signed-off-by: Akshat Tripathi --- tests/lora/test_punica_variation.py | 7 ++++--- vllm/worker/cpu_model_runner.py | 26 +++++++++++--------------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 9e8f3196f2731..f280c758c6fd1 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -26,9 +26,9 @@ sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice sgmv_shrink = torch.ops.vllm.sgmv_shrink else: - from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) + from vllm.lora.ops.default.lora_ops import ( # type: ignore + bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) from vllm.platforms import current_platform @@ -57,6 +57,7 @@ def assert_close(a, b): }[a.dtype] torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 593f70b7fd564..db8a03abf0cc5 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -2,7 +2,7 @@ import weakref from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, TypeVar, Union) import torch @@ -24,7 +24,6 @@ MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.utils import 
make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -198,7 +197,7 @@ def build(self) -> ModelInputForCPU: input_data.seq_lens, input_data.query_lens, -1, -1) is_prompt = (self.seq_group_metadata_list[0].is_prompt - if self.seq_group_metadata_list else None) + if self.seq_group_metadata_list else None) # LoRA data. lora_requests = set() lora_mapping = None @@ -210,17 +209,15 @@ def build(self) -> ModelInputForCPU: lora_mapping = self._prepare_lora_input( self.seq_group_metadata_list, is_prompt) - return self.model_input_cls( - input_tokens=input_tokens, - input_positions=input_positions, - token_type_ids=token_type_ids, - seq_lens=input_data.seq_lens, - query_lens=input_data.query_lens, - attn_metadata=attn_metadata, - multi_modal_kwargs=multi_modal_kwargs, - lora_mapping=lora_mapping, - lora_requests=lora_requests - ) + return self.model_input_cls(input_tokens=input_tokens, + input_positions=input_positions, + token_type_ids=token_type_ids, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + lora_mapping=lora_mapping, + lora_requests=lora_requests) def _build_input_data(self): for seq_group_metadata in self.seq_group_metadata_list: @@ -411,7 +408,6 @@ def _compute_multi_modal_input(self, self.input_data.multi_modal_placeholder_maps[modality].extend( placeholder_map) - def _prepare_lora_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], is_prefill: bool) -> LoRAMapping:
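Taken together, the lora_ops.py hunks above converge on a single shape for the default (non-Triton) BGMV path: gather one LoRA matrix per token, drop the unit packing dimension if the stacked weights carry one, and contract with a batched einsum. The condensed, self-contained sketch below restates that pattern outside the diff context; the shapes in the comments are illustrative rather than enforced.

import torch


def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True) -> None:
    # Gather one LoRA B matrix per token and match the output dtype,
    # e.g. (num_tokens, out_dim, rank) from (num_loras, out_dim, rank).
    selected_loras = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    # Stacked weights may keep a unit packing dim: (num_loras, 1, out_dim, rank).
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    # Batched matrix-vector product: (b, rank) x (b, out_dim, rank) -> (b, out_dim).
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
    if add_inputs:
        output_tensor[:] += outputs
    else:
        output_tensor[:] = outputs

With inputs of shape (num_tokens, rank), lora_b_weights of shape (num_loras, 1, out_dim, rank) and a per-token lora_indices_tensor, this writes a (num_tokens, out_dim) update into output_tensor; the shrink and slice variants in the hunks follow the same gather-and-einsum structure.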
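The test-side changes follow one pattern as well: device parametrization goes through a platform check instead of hard-coded CUDA strings, and the tensor-parallel skips only fire on CUDA-alike platforms. A minimal pytest sketch of that pattern, assuming only current_platform.is_cuda_alike() from vllm.platforms; the helper name _maybe_skip_for_tp and the sample test are hypothetical.

import pytest
from vllm.platforms import current_platform

CUDA_DEVICES = ["cuda:0"]
CPU_DEVICES = ["cpu"]
# Fall back to CPU parametrization on hosts without a CUDA-alike device.
DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES


def _maybe_skip_for_tp(tp_size: int, num_gpus_available: int) -> None:
    # Skip only when tensor parallelism genuinely needs more GPUs than are
    # present *and* the platform is CUDA-alike; single-device CPU runs proceed.
    if num_gpus_available < tp_size and tp_size > 1 \
            and current_platform.is_cuda_alike():
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")


@pytest.mark.parametrize("device", DEVICES)
def test_runs_on_selected_device(device):
    # On CPU-only hosts this runs once with device == "cpu"; on CUDA-alike
    # platforms it runs with device == "cuda:0".
    assert device in CUDA_DEVICES + CPU_DEVICES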