diff --git a/vllm/envs.py b/vllm/envs.py
index ab12a7b48dc53..be5d9985b63a4 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -69,6 +69,7 @@
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
+    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
 
 
 def get_default_cache_root():
@@ -452,6 +453,8 @@ def get_default_config_root():
     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":
     lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
+    "VLLM_LOG_BATCHSIZE_INTERVAL":
+    lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index aaa3e4bb3a1e8..cd136f43c0c57 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -1,8 +1,19 @@
+import time
+from collections import Counter
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
+batchsize_counter: Counter = Counter()
+last_logging_time: float = 0
+batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 
 
 @dataclass
@@ -26,7 +37,26 @@ def get_forward_context() -> ForwardContext:
 @contextmanager
 def set_forward_context(context: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
-    can be attention metadata, etc."""
+    can be attention metadata, etc.
+    Here we can inject common logic for every model forward pass.
+    """
+    global track_batchsize, batchsize_counter
+    global last_logging_time, batchsize_logging_interval
+    if track_batchsize and context is not None:
+        if hasattr(context, "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        else:
+            # for v1 attention backends
+            batchsize = context.num_input_tokens
+        batchsize_counter[batchsize] += 1
+        if time.monotonic() - last_logging_time > batchsize_logging_interval:
+            last_logging_time = time.monotonic()
+            sorted_data = sorted(batchsize_counter.items(),
+                                 key=lambda x: x[1],
+                                 reverse=True)
+            logger.info("Batchsize distribution (batchsize, count): %s",
+                        sorted_data)
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 251a103e60f06..c9f04ace644c7 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
     seq_start_loc: torch.Tensor
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
+    num_input_tokens: int = 0  # Number of tokens including padding.
 
 
 class FlashAttentionImpl(AttentionImpl):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0a5adfb28c9bd..a3335fa838352 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -445,6 +445,8 @@ def execute_model(
             # Eager mode.
             num_input_tokens = num_scheduled_tokens
 
+        attn_metadata.num_input_tokens = num_input_tokens
+
         # Get the inputs embeds.
         if encoder_outputs:
             inputs_embeds = self.model.get_input_embeddings(
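A minimal sketch of how the new knob could be exercised end to end. The only surface taken from this diff is the VLLM_LOG_BATCHSIZE_INTERVAL environment variable (an interval in seconds, per the time.monotonic() check in set_forward_context); the model name, prompt, and sampling settings below are placeholders, and the vLLM LLM/SamplingParams calls are the library's standard offline-inference API.

import os

# The env var is read when vllm.forward_context is imported (module-level
# track_batchsize / batchsize_logging_interval above), so set it before
# importing vLLM. -1 (the default) disables the tracking entirely.
os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "10"

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["Hello, my name is"] * 8,
                       SamplingParams(max_tokens=32))

# With the variable >= 0, every forward pass records its (padded) batch size in
# batchsize_counter, and roughly every 10 seconds a line of this shape appears:
#   Batchsize distribution (batchsize, count): [(8, 3), (1, 1), ...]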