From a9b7212b185db4ad129b5d4e920b8dd690c5e1d8 Mon Sep 17 00:00:00 2001
From: mzusman
Date: Sun, 22 Dec 2024 14:48:29 +0200
Subject: [PATCH 1/3] The maximum number of running requests can be higher
 than the max batch size, causing an error inside the mamba cache manager;
 setting the max batch size to twice the max num seqs ensures that new
 requests will have spare slots in the mamba cache manager

Signed-off-by: mzusman
---
 vllm/model_executor/models/jamba.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 91786db5ddc96..c6b6aa1a3107f 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -422,17 +422,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
+
+        effective_max_batch_size = int(self.vllm_config.scheduler_config.max_num_seqs * 2)
+        if not self.model_config.enforce_eager \
+            and effective_max_batch_size <= \
                 vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
+            self.max_batch_size = vllm_config.pad_for_cudagraph(
+                    effective_max_batch_size)
         else:
-            self.max_batch_size = 8192 + 2
+            self.max_batch_size = effective_max_batch_size
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

From 48efd2f847a332f5506353f41bb23a001e543ff4 Mon Sep 17 00:00:00 2001
From: mzusman
Date: Sun, 22 Dec 2024 14:59:51 +0200
Subject: [PATCH 2/3] Format

Signed-off-by: mzusman
---
 vllm/model_executor/models/jamba.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index c6b6aa1a3107f..1debb198560b1 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -423,12 +423,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        effective_max_batch_size = int(self.vllm_config.scheduler_config.max_num_seqs * 2)
+        effective_max_batch_size = int(
+            self.vllm_config.scheduler_config.max_num_seqs * 2)
         if not self.model_config.enforce_eager \
             and effective_max_batch_size <= \
                 vllm_config.compilation_config.max_capture_size:
             self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    effective_max_batch_size)
+                effective_max_batch_size)
         else:
             self.max_batch_size = effective_max_batch_size
 

From 7d2fef50de3d78976ca8f3d715a2bedc418b85cb Mon Sep 17 00:00:00 2001
From: mzusman
Date: Sun, 22 Dec 2024 15:40:36 +0200
Subject: [PATCH 3/3] Introduce the multiplier through an env var

Signed-off-by: mzusman
---
 vllm/envs.py                        |  3 +++
 vllm/model_executor/models/jamba.py |  5 ++++-
 vllm/model_executor/models/mamba.py | 20 +++++++++++---------
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index 18870c1c6b51a..cdc4a8ca81fe6 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -72,6 +72,7 @@
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
+    VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER: float = 1.5
 
 
 def get_default_cache_root():
@@ -466,6 +467,8 @@ def get_default_config_root():
     lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
     "VLLM_DISABLE_COMPILE_CACHE":
     lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
+    "VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER":
+    lambda: float(os.getenv("VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER", "1.5")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 1debb198560b1..92804d19cd704 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -5,6 +5,7 @@
 from torch import nn
 from transformers import JambaConfig
 
+from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
@@ -424,7 +425,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.model.make_empty_intermediate_tensors)
 
         effective_max_batch_size = int(
-            self.vllm_config.scheduler_config.max_num_seqs * 2)
+            self.vllm_config.scheduler_config.max_num_seqs * \
+            envs.VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER
+        )
         if not self.model_config.enforce_eager \
             and effective_max_batch_size <= \
                 vllm_config.compilation_config.max_capture_size:
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 06c8d9723cd01..ee5c26cde308b 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -5,6 +5,7 @@
 from torch import nn
 from transformers import MambaConfig
 
+from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -195,17 +196,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
+
+        effective_max_batch_size = int(
+            self.vllm_config.scheduler_config.max_num_seqs * \
+            envs.VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER
+        )
+        if not self.model_config.enforce_eager \
+            and effective_max_batch_size <= \
                 vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
+            self.max_batch_size = vllm_config.pad_for_cudagraph(
+                effective_max_batch_size)
         else:
-            self.max_batch_size = 8192 + 2
+            self.max_batch_size = effective_max_batch_size
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
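
Editorial illustration (a minimal standalone sketch, not part of the commits above): the slot-sizing arithmetic this series introduces, shown outside vLLM. The multiplier name and its 1.5 default come from the VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER entry added to vllm/envs.py; max_num_seqs below is a made-up stand-in for what vLLM reads from SchedulerConfig, and the CUDA-graph padding done by vllm_config.pad_for_cudagraph() is only described in a comment rather than reproduced.

import os

# Stand-in value; in vLLM this comes from SchedulerConfig.max_num_seqs
# (the number here is illustrative only).
max_num_seqs = 256

# Multiplier added by this series; defaults to 1.5 and can be overridden
# through the VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER environment variable.
multiplier = float(os.getenv("VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER", "1.5"))

# Same arithmetic as the patched __init__: allocate more mamba cache slots
# than max_num_seqs, so a newly scheduled request always finds a free slot.
effective_max_batch_size = int(max_num_seqs * multiplier)

print(effective_max_batch_size)  # 384 with the stand-in values above

Raising the multiplier (for example, launching with VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER=2.0 in the environment) trades extra mamba cache memory for more headroom; when CUDA graphs are enabled, the patched code additionally passes this value through vllm_config.pad_for_cudagraph() so the cache size lines up with a captured graph batch size.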