diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 29ecf37808205..df04c3e0390cb 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -20,6 +20,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader import get_model
+from vllm.platforms import current_platform


 class ContextIDInfo(TypedDict):
@@ -64,13 +65,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
 @pytest.fixture
 def dist_init():
     temp_file = tempfile.mkstemp()[1]
-    init_distributed_environment(
-        world_size=1,
-        rank=0,
-        distributed_init_method=f"file://{temp_file}",
-        local_rank=0,
-        backend="nccl",
-    )
+
+    backend = "nccl"
+    if current_platform.is_cpu():
+        backend = "gloo"
+
+    init_distributed_environment(world_size=1,
+                                 rank=0,
+                                 distributed_init_method=f"file://{temp_file}",
+                                 local_rank=0,
+                                 backend=backend)
     initialize_model_parallel(1, 1)
     yield
     cleanup_dist_env_and_memory(shutdown_ray=True)
@@ -80,13 +84,15 @@ def dist_init():
 def dist_init_torch_only():
     if torch.distributed.is_initialized():
         return
+    backend = "nccl"
+    if current_platform.is_cpu():
+        backend = "gloo"
+
     temp_file = tempfile.mkstemp()[1]
-    torch.distributed.init_process_group(
-        backend="nccl",
-        world_size=1,
-        rank=0,
-        init_method=f"file://{temp_file}",
-    )
+    torch.distributed.init_process_group(world_size=1,
+                                         rank=0,
+                                         init_method=f"file://{temp_file}",
+                                         backend=backend)


 @pytest.fixture
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index a113e3f7abc1e..4369cb6bd6630 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -48,10 +48,19 @@
     torch.float32: (5e-3, 5e-3),
     torch.bfloat16: (3e-2, 2e-2),
 }
+
+pytestmark = pytest.mark.skipif(
+    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
+    reason="Backend not supported")
+
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+CPU_DEVICES = ["cpu"]
+
+DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES
+

 # We will launch different triton kernels between the prefill and decode
 # stages, so we need to verify this. prefill stage(True) or decode stage(False)
 STAGES = [True, False]
@@ -194,14 +203,15 @@ def create_random_inputs(

 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
 @pytest.mark.parametrize("stage", STAGES)
 def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
     # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
     # device, see: https://github.com/triton-lang/triton/issues/2925
     # Same below.
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
     torch.set_default_device(device)
     max_loras = 8
@@ -296,13 +306,15 @@ def create_random_embedding_layer():
 # @pytest.mark.skip(
 #     reason="Fails when loras are in any slot other than the first.")
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
 @pytest.mark.parametrize("stage", STAGES)
 def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                         vocab_size, stage) -> None:
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -432,13 +444,15 @@ def create_random_embedding_layer():

 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
 @pytest.mark.parametrize("stage", STAGES)
 def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                   stage) -> None:
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = PunicaWrapper(8192, 256, device)
@@ -563,13 +577,15 @@ def _pretest():

 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("stage", STAGES)
 @pytest.mark.parametrize("bias_enabled", [True, False])
 def test_linear_replicated(dist_init, num_loras, device, stage,
                            bias_enabled) -> None:
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -675,13 +691,15 @@ def create_random_linear_replicated_layer():
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("orientation", ["row", "column"])
 @pytest.mark.parametrize("fully_shard", [True, False])
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("stage", STAGES)
 @pytest.mark.parametrize("bias_enabled", [True, False])
 def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                          device, stage, bias_enabled) -> None:
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -797,13 +815,15 @@ def create_random_linear_parallel_layer():
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("repeats", [1, 2, 3])
 @pytest.mark.parametrize("fully_shard", [True, False])
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("stage", STAGES)
 @pytest.mark.parametrize("bias_enabled", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                 device, stage, bias_enabled) -> None:
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     punica_wrapper = PunicaWrapper(8192, 256, device)
     max_loras = 8
@@ -955,6 +975,8 @@ class FakeConfig:
 @pytest.mark.parametrize("rotary_dim", [None, 32])
 @pytest.mark.parametrize("head_size", [32, 108])
 @pytest.mark.parametrize("seq_len", [11, 1024])
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
+                    reason="Only CUDA backends are supported")
 def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                        scaling_factors, max_position,
                                        is_neox_style, rotary_dim, head_size,
diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 8d109b2c81503..6b4bea32886bc 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -17,6 +17,7 @@
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
 from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.platforms import current_platform

 EMBEDDING_MODULES = {
     "embed_tokens": "input_embeddings",
@@ -28,9 +29,12 @@
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+CPU_DEVICES = ["cpu"]

+DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES

-@pytest.mark.parametrize("device", CUDA_DEVICES)
+
+@pytest.mark.parametrize("device", DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
@@ -113,7 +117,7 @@ def test_replace_submodules(dist_init, dummy_model):
     manager = LoRAModelManager(
         model, 1, 1, 1,
         LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device("cuda"))
+        torch.device(DEVICES[0]))
     model = manager.model

     assert isinstance(model.get_submodule("dense1"),
@@ -125,7 +129,7 @@ def test_replace_submodules(dist_init, dummy_model):
                       RowParallelLinearWithLoRA)


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lora_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -186,7 +190,7 @@ def test_lora_model_manager(dist_init, dummy_model, device):
     assert manager.punica_wrapper.device == device


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -278,7 +282,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     assert manager.device == device


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lru_lora_model_manager(dist_init, dummy_model, device):
     # This tests just the LRU cache functionality, everything else is
     # tested in test_lora_model_manager
@@ -408,7 +412,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
     assert manager.device == device


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
@@ -487,7 +491,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                         device)


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
@@ -563,7 +567,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                         device)


-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_packed_loras(dist_init, dummy_model_gate_up, device):
     model = dummy_model_gate_up
     model.supported_lora_modules = ["gate_up_proj"]
diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py
index dddc299da446b..31237acd549eb 100644
--- a/tests/lora/test_mixtral.py
+++ b/tests/lora/test_mixtral.py
@@ -5,6 +5,7 @@

 import vllm
 from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform

 MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"

@@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
 @pytest.mark.parametrize("tp_size", [4])
 def test_mixtral_lora(mixtral_lora_files, tp_size):
     """Original test, the LoRA model has the common target modules, not all"""
-    if torch.cuda.device_count() < tp_size:
+    if torch.cuda.device_count(
+    ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike():
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

     prompts = [
diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py
index 66b5f82bbb97d..af31748f8cbd6 100644
--- a/tests/lora/test_punica_sizes.py
+++ b/tests/lora/test_punica_sizes.py
@@ -7,12 +7,8 @@
 import pytest
 import torch

-from vllm.lora.ops.bgmv_expand import bgmv_expand
-from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
-from vllm.lora.ops.bgmv_shrink import bgmv_shrink
-from vllm.lora.ops.sgmv_expand import sgmv_expand
-from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
-from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink,
+                           sgmv_expand, sgmv_expand_slice, sgmv_shrink)
 from vllm.platforms import current_platform

 from .utils import (generate_data, generate_data_for_expand_nslices,
@@ -110,7 +106,10 @@
 MAX_RANKS = [32]
 SCALES = [0.5]
 SEED = [0]
-CUDA_DEVICES = [f"cuda:{0}"]
+
+CUDA_DEVICES = ["cuda:0"]
+CPU_DEVICES = ["cpu"]
+DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES


 def assert_close(a, b):
@@ -130,7 +129,7 @@ def assert_close(a, b):
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
 @pytest.mark.parametrize("seed", SEED)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_punica_sgmv(
     batches: int,
     num_loras: int,
@@ -220,7 +219,7 @@ def test_punica_sgmv(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
 @pytest.mark.parametrize("seed", SEED)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_punica_bgmv(
     batches: int,
     num_loras: int,
@@ -294,7 +293,7 @@ def test_punica_bgmv(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
 @pytest.mark.parametrize("seed", SEED)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_punica_expand_nslices(
     batches: int,
     num_loras: int,
diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py
index 3b20033271d26..f280c758c6fd1 100644
--- a/tests/lora/test_punica_variation.py
+++ b/tests/lora/test_punica_variation.py
@@ -6,13 +6,30 @@
 import pytest
 import torch

-# Enable custom op register
-import vllm.lora.ops.bgmv_expand
-import vllm.lora.ops.bgmv_expand_slice
-import vllm.lora.ops.bgmv_shrink
-import vllm.lora.ops.sgmv_expand
-import vllm.lora.ops.sgmv_expand_slice
-import vllm.lora.ops.sgmv_shrink  # noqa: F401
+from vllm.triton_utils import HAS_TRITON
+
+# Enable custom op register if we're using custom ops
+if HAS_TRITON:
+    import vllm.lora.ops.triton.bgmv_expand
+    import vllm.lora.ops.triton.bgmv_expand_slice
+    import vllm.lora.ops.triton.bgmv_shrink
+    import vllm.lora.ops.triton.sgmv_expand
+    import vllm.lora.ops.triton.sgmv_expand_slice
+    import vllm.lora.ops.triton.sgmv_shrink  # noqa: F401
+
+    # Unlike test_punica_sizes.py, we directly utilize custom op for
+    # testing, which verifies the correct registration of these ops.
+    bgmv_expand = torch.ops.vllm.bgmv_expand
+    bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
+    bgmv_shrink = torch.ops.vllm.bgmv_shrink
+    sgmv_expand = torch.ops.vllm.sgmv_expand
+    sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
+    sgmv_shrink = torch.ops.vllm.sgmv_shrink
+else:
+    from vllm.lora.ops.default.lora_ops import (  # type: ignore
+        bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
+        sgmv_expand_slice, sgmv_shrink)
+
 from vllm.platforms import current_platform

 from .utils import (generate_data, generate_data_for_expand_nslices,
@@ -26,7 +43,10 @@
 MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
 SCALES = [0.5]
 SEED = [0]
-CUDA_DEVICES = [f"cuda:{0}"]
+
+CUDA_DEVICES = ["cuda:0"]
+CPU_DEVICES = ["cpu"]
+DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES


 def assert_close(a, b):
@@ -38,16 +58,6 @@ def assert_close(a, b):
     torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


-# Unlike test_punica_sizes.py, we directly utilize custom op for
-# testing, which verifies the correct registration of these ops.
-bgmv_expand = torch.ops.vllm.bgmv_expand
-bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
-bgmv_shrink = torch.ops.vllm.bgmv_shrink
-sgmv_expand = torch.ops.vllm.sgmv_expand
-sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
-sgmv_shrink = torch.ops.vllm.sgmv_shrink
-
-
 @pytest.mark.parametrize("batches", BATCHES)
 @pytest.mark.parametrize("num_loras", NUM_LORA)
 @pytest.mark.parametrize("rank", MAX_RANKS)
@@ -56,7 +66,7 @@ def assert_close(a, b):
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
 @pytest.mark.parametrize("seed", SEED)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_punica_sgmv(
     batches: int,
     num_loras: int,
@@ -146,7 +156,7 @@ def test_punica_sgmv(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["shrink", "expand"])
 @pytest.mark.parametrize("seed", SEED)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_punica_bgmv(
     batches: int,
     num_loras: int,
@@ -222,7 +232,7 @@ def test_punica_bgmv(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
 @pytest.mark.parametrize("seed", SEED)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_punica_expand_nslices(
     batches: int,
     num_loras: int,
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py
index 5432fa4ad0d3a..c2590594a277b 100644
--- a/tests/lora/test_quant_model.py
+++ b/tests/lora/test_quant_model.py
@@ -72,7 +72,8 @@ def format_prompt_tuples(prompt):
 @pytest.mark.parametrize("tp_size", [1])
 def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
                           tp_size):
-    if num_gpus_available < tp_size:
+    if num_gpus_available < tp_size and \
+        tp_size > 1 and current_platform.is_cuda_alike():
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

     llm = vllm.LLM(
diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py
index e69de29bb2d1d..7a13eabeb6074 100644
--- a/vllm/lora/ops/__init__.py
+++ b/vllm/lora/ops/__init__.py
@@ -0,0 +1,18 @@
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.lora.ops.triton.bgmv_expand import bgmv_expand
+    from vllm.lora.ops.triton.bgmv_expand_slice import bgmv_expand_slice
+    from vllm.lora.ops.triton.bgmv_shrink import bgmv_shrink
+    from vllm.lora.ops.triton.sgmv_expand import sgmv_expand
+    from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice
+    from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink
+else:
+    from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice,
+                                                bgmv_shrink, sgmv_expand,
+                                                sgmv_expand_slice, sgmv_shrink)
+
+__all__ = [
+    "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "sgmv_expand",
+    "sgmv_expand_slice", "sgmv_shrink"
+]
diff --git a/vllm/lora/ops/default/__init__.py b/vllm/lora/ops/default/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py
new file mode 100644
index 0000000000000..5f5aafd516159
--- /dev/null
+++ b/vllm/lora/ops/default/lora_ops.py
@@ -0,0 +1,113 @@
+import torch
+
+
+def sgmv_expand(inputs: torch.Tensor,
+                lora_b_weights: torch.Tensor,
+                output_tensor: torch.Tensor,
+                b_seq_start_loc: torch.Tensor,
+                seq_len_tensor: torch.Tensor,
+                lora_indices_tensor: torch.Tensor,
+                batches: int,
+                max_seq_length: int,
+                token_nums: int,
+                add_inputs: bool = False):
+    exploded_indices = torch.repeat_interleave(lora_indices_tensor,
+                                               seq_len_tensor)
+
+    bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
+                add_inputs)
+
+
+def bgmv_expand(inputs: torch.Tensor,
+                lora_b_weights: torch.Tensor,
+                output_tensor: torch.Tensor,
+                lora_indices_tensor: torch.Tensor,
+                add_inputs: bool = True):
+    selected_loras = lora_b_weights[lora_indices_tensor].to(
+        dtype=output_tensor.dtype)
+    if len(selected_loras.shape) == 4:
+        selected_loras = selected_loras.squeeze(dim=1)
+    inputs = inputs.to(dtype=output_tensor.dtype)
+    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
+
+    limit = output_tensor.shape[0]
+    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
+        limit = 1
+
+    if add_inputs:
+        output_tensor[:, :outputs.shape[1]] += outputs[:limit, :]
+    else:
+        output_tensor[:, :outputs.shape[1]] = outputs[:limit, :]
+
+
+def sgmv_shrink(
+    inputs: torch.Tensor,
+    lora_a_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    token_nums: int,
+    scaling: float,
+):
+    exploded_indices = torch.repeat_interleave(lora_indices_tensor,
+                                               seq_len_tensor)
+
+    bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
+                scaling)
+
+
+def bgmv_shrink(inputs: torch.Tensor,
+                lora_b_weights: torch.Tensor,
+                output_tensor: torch.Tensor,
+                lora_indices_tensor: torch.Tensor,
+                scaling: float = 1.0):
+    selected_loras = lora_b_weights[lora_indices_tensor].to(
+        dtype=output_tensor.dtype)
+    if len(selected_loras.shape) == 4:
+        selected_loras = selected_loras.squeeze(dim=1)
+    inputs = inputs.to(dtype=output_tensor.dtype)
+    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
+
+    output_tensor[:, :outputs.shape[1]] = scaling * outputs[:]
+
+
+def sgmv_expand_slice(inputs: torch.Tensor,
+                      lora_b_weights: torch.Tensor,
+                      output_tensor: torch.Tensor,
+                      b_seq_start_loc: torch.Tensor,
+                      seq_len_tensor: torch.Tensor,
+                      lora_indices_tensor: torch.Tensor,
+                      batches: int,
+                      max_seq_length: int,
+                      token_nums: int,
+                      slice_offset: int,
+                      slice_size: int,
+                      add_inputs: bool = False):
+    exploded_indices = torch.repeat_interleave(lora_indices_tensor,
+                                               seq_len_tensor)
+
+    bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
+                      slice_offset, slice_size, add_inputs)
+
+
+def bgmv_expand_slice(inputs: torch.Tensor,
+                      lora_b_weights: torch.Tensor,
+                      output_tensor: torch.Tensor,
+                      lora_indices_tensor: torch.Tensor,
+                      slice_offset: int,
+                      slice_size: int,
+                      add_inputs: bool = True):
+    selected_loras = lora_b_weights[lora_indices_tensor].to(
+        dtype=output_tensor.dtype)
+    inputs = inputs.to(dtype=output_tensor.dtype)
+    if len(selected_loras.shape) == 4:
+        selected_loras = selected_loras.squeeze(dim=1)
+    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
+
+    if add_inputs:
+        output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:]
+    else:
+        output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:]
diff --git a/vllm/lora/ops/triton/__init__.py b/vllm/lora/ops/triton/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/triton/bgmv_expand.py
similarity index 100%
rename from vllm/lora/ops/bgmv_expand.py
rename to vllm/lora/ops/triton/bgmv_expand.py
diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/triton/bgmv_expand_slice.py
similarity index 100%
rename from vllm/lora/ops/bgmv_expand_slice.py
rename to vllm/lora/ops/triton/bgmv_expand_slice.py
diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/triton/bgmv_shrink.py
similarity index 100%
rename from vllm/lora/ops/bgmv_shrink.py
rename to vllm/lora/ops/triton/bgmv_shrink.py
diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/triton/sgmv_expand.py
similarity index 100%
rename from vllm/lora/ops/sgmv_expand.py
rename to vllm/lora/ops/triton/sgmv_expand.py
diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/triton/sgmv_expand_slice.py
similarity index 100%
rename from vllm/lora/ops/sgmv_expand_slice.py
rename to vllm/lora/ops/triton/sgmv_expand_slice.py
diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/triton/sgmv_shrink.py
similarity index 100%
rename from vllm/lora/ops/sgmv_shrink.py
rename to vllm/lora/ops/triton/sgmv_shrink.py
diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/triton/utils.py
similarity index 100%
rename from vllm/lora/ops/utils.py
rename to vllm/lora/ops/triton/utils.py
diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index 563d1181d6fcb..fc95a40582b1e 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -9,15 +9,8 @@

 import torch

-from vllm.triton_utils import HAS_TRITON
-
-if HAS_TRITON:
-    from vllm.lora.ops.bgmv_expand import bgmv_expand
-    from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
-    from vllm.lora.ops.bgmv_shrink import bgmv_shrink
-    from vllm.lora.ops.sgmv_expand import sgmv_expand
-    from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
-    from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink,
+                           sgmv_expand, sgmv_expand_slice, sgmv_shrink)

 if TYPE_CHECKING:
     # avoid circuit import
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 420aaf8a1b4cd..db8a03abf0cc5 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -2,8 +2,8 @@
 import weakref
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
-                    Union)
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
+                    TypeVar, Union)

 import torch
 from torch import nn
@@ -12,10 +12,14 @@
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import supports_lora, supports_multimodal
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap)
 from vllm.sequence import (IntermediateTensors, SequenceData,
@@ -49,6 +53,8 @@ class ModelInputForCPU(ModelRunnerInputBase):
     virtual_engine: Optional[int] = None
     seq_lens: Optional[List[int]] = None
     query_lens: Optional[List[int]] = None
+    lora_mapping: Optional["LoRAMapping"] = None
+    lora_requests: Optional[Set[LoRARequest]] = None

     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -57,6 +63,8 @@ def as_broadcastable_tensor_dict(
             "input_positions": self.input_positions,
             "token_type_ids": self.token_type_ids,
             "multi_modal_kwargs": self.multi_modal_kwargs,
+            "lora_requests": self.lora_requests,
+            "lora_mapping": self.lora_mapping,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict,
                                               self.attn_metadata)
@@ -145,7 +153,11 @@ def __init__(self,
                               or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
         self.attn_backend = self.runner.attn_backend
+        self.sliding_window = self.runner.sliding_window
+        self.block_size = self.runner.block_size
+        self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
+        self.enable_lora = self.runner.lora_config is not None
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
         self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
@@ -184,15 +196,28 @@ def build(self) -> ModelInputForCPU:
             attn_metadata = self.att_metadata_builder.build(
                 input_data.seq_lens, input_data.query_lens, -1, -1)

-        return self.model_input_cls(
-            input_tokens=input_tokens,
-            input_positions=input_positions,
-            token_type_ids=token_type_ids,
-            seq_lens=input_data.seq_lens,
-            query_lens=input_data.query_lens,
-            attn_metadata=attn_metadata,
-            multi_modal_kwargs=multi_modal_kwargs,
-        )
+        is_prompt = (self.seq_group_metadata_list[0].is_prompt
+                     if self.seq_group_metadata_list else None)
+        # LoRA data.
+        lora_requests = set()
+        lora_mapping = None
+        if self.enable_lora:
+            lora_requests = set(seq.lora_request
+                                for seq in self.seq_group_metadata_list
+                                if seq.lora_request is not None)
+
+            lora_mapping = self._prepare_lora_input(
+                self.seq_group_metadata_list, is_prompt)
+
+        return self.model_input_cls(input_tokens=input_tokens,
+                                    input_positions=input_positions,
+                                    token_type_ids=token_type_ids,
+                                    seq_lens=input_data.seq_lens,
+                                    query_lens=input_data.query_lens,
+                                    attn_metadata=attn_metadata,
+                                    multi_modal_kwargs=multi_modal_kwargs,
+                                    lora_mapping=lora_mapping,
+                                    lora_requests=lora_requests)

     def _build_input_data(self):
         for seq_group_metadata in self.seq_group_metadata_list:
@@ -383,6 +408,24 @@ def _compute_multi_modal_input(self,
             self.input_data.multi_modal_placeholder_maps[modality].extend(
                 placeholder_map)

+    def _prepare_lora_input(
+            self, seq_group_metadata_list: List[SequenceGroupMetadata],
+            is_prefill: bool) -> LoRAMapping:
+        index_mapping = []
+        prompt_mapping = []
+        for seq in seq_group_metadata_list:
+            lora_id = seq.lora_int_id
+            query_len = seq.token_chunk_size
+
+            index_mapping += [lora_id] * query_len
+            prompt_mapping += [lora_id] * (
+                query_len if seq.sampling_params
+                and seq.sampling_params.prompt_logprobs is not None else 1)
+
+        return LoRAMapping(index_mapping=tuple(index_mapping),
+                           prompt_mapping=tuple(prompt_mapping),
+                           is_prefill=is_prefill)
+

 class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
@@ -433,10 +476,41 @@ def __init__(

         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
+        # Set after load_model.
+        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
+        if self.lora_config:
+            assert supports_lora(
+                self.model
+            ), f"{self.model.__class__.__name__} does not support LoRA yet."
+
+            if supports_multimodal(self.model):
+                logger.warning("Regarding multimodal models, vLLM currently "
+                               "only supports adding LoRA to language model.")
+
+            # It's necessary to distinguish between the max_position_embeddings
+            # of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = self.model.config.max_position_embeddings
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)
+
+            self.lora_manager = LRUCacheWorkerLoRAManager(
+                self.scheduler_config.max_num_seqs,
+                self.scheduler_config.max_num_batched_tokens,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
+                self.model.embedding_modules,
+                self.model.embedding_padding_modules,
+                max_position_embeddings=max_pos_embeddings,
+            )
+            self.model = self.lora_manager.create_lora_manager(self.model)
+

     def _prepare_model_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -461,6 +535,37 @@ def sampler(self):
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()

+    def remove_all_loras(self):
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.remove_all_adapters()
+
+    def set_active_loras(self, lora_requests: Set[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_adapter(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_adapter(lora_id)
+
+    def pin_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.pin_adapter(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_adapters()
+

 class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
     _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
@@ -517,6 +622,12 @@ def execute_model(
             raise ValueError(
                 "CPU worker does not support multi-step execution.")

+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
         model_executable = self.model

         multimodal_kwargs = {}
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 4fad1a3f4caeb..cb78d32828245 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -1,5 +1,5 @@
 """A CPU worker class."""
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Dict, List, Optional, Set, Tuple, Type

 import torch
 import torch.distributed
@@ -11,14 +11,14 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
-from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
-                                     LoraNotSupportedWorkerBase, WorkerBase,
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)

 logger = init_logger(__name__)
@@ -111,7 +111,7 @@ def get_cache_block_size(

         return dtype_size * total

-class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
+class CPUWorker(LocalOrDistributedWorkerBase):
     """A worker class that executes (a partition of) the model on a CPU socket.

     Each worker is associated with a single CPU socket. The worker is
@@ -266,6 +266,18 @@ def initialize_cache(self, num_gpu_blocks: int,
         # Initialize the cache.
         self._init_cache_engine()

+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_runner.list_loras()
+
     def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
         """Raise errors if the num_cpu_blocks is invalid.
         """
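
A quick way to sanity-check the new non-Triton fallback in isolation is the minimal sketch below. It is not part of the patch: it drives bgmv_shrink/bgmv_expand from vllm/lora/ops/default/lora_ops.py against an explicit per-token reference. The stacked-weight layout assumed here ([num_loras, 1, rank, hidden] for lora_a and [num_loras, 1, out_dim, rank] for lora_b) is inferred from the squeeze(dim=1) handling in those ops, not something the patch states explicitly.

import torch

from vllm.lora.ops.default.lora_ops import bgmv_expand, bgmv_shrink

num_loras, rank, hidden, out_dim, num_tokens = 4, 8, 64, 128, 6
torch.manual_seed(0)

x = torch.randn(num_tokens, hidden, dtype=torch.float32)
# Assumed stacked layout: [num_loras, 1, rank, hidden] / [num_loras, 1, out_dim, rank].
lora_a = torch.randn(num_loras, 1, rank, hidden, dtype=torch.float32)
lora_b = torch.randn(num_loras, 1, out_dim, rank, dtype=torch.float32)
# Per-token LoRA slot index, analogous to what LoRAMapping/PunicaWrapper produce.
indices = torch.randint(0, num_loras, (num_tokens, ))

# shrink: buffer = scaling * (x @ A_i^T); expand: y += buffer @ B_i^T
buffer = torch.zeros(num_tokens, rank, dtype=torch.float32)
y = torch.zeros(num_tokens, out_dim, dtype=torch.float32)
bgmv_shrink(x, lora_a, buffer, indices, scaling=0.5)
bgmv_expand(buffer, lora_b, y, indices, add_inputs=True)

# Reference: apply each token's LoRA pair explicitly.
ref = torch.stack([
    0.5 * x[t] @ lora_a[indices[t], 0].T @ lora_b[indices[t], 0].T
    for t in range(num_tokens)
])
torch.testing.assert_close(y, ref, rtol=1e-5, atol=1e-5)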