diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index e6ddca69bf01b..2efe142a17b69 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -65,11 +65,6 @@ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": def get_builder_cls() -> Type["AttentionMetadataBuilder"]: raise NotImplementedError - @classmethod - def make_metadata_builder(cls, *args, - **kwargs) -> "AttentionMetadataBuilder": - return cls.get_builder_cls()(*args, **kwargs) - @staticmethod @abstractmethod def get_kv_cache_shape( @@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]): @abstractmethod def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: + """Create the builder, remember some configuration and parameters.""" + raise NotImplementedError + + @abstractmethod + def prepare(self) -> None: + """Prepare for one batch.""" raise NotImplementedError @abstractmethod diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 40250ef08b595..60ed09d0cc44f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] @@ -388,11 +394,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_decode_tokens = 0 self.has_prefix_cache_hit = False - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", chunked_prefill_enabled: bool, prefix_cache_hit: bool): diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b9cd805e81b45..b8ffbe6dd64dd 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -488,6 +488,14 @@ def advance_step(self, class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] @@ -500,12 +508,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 - self.input_builder = input_builder - self.runner = input_builder.runner - - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # for the precise definition of the following fields. # An example: diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 534f79b3a60bf..37860494702cf 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder( AttentionMetadataBuilder[PlaceholderAttentionMetadata]): def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + def prepare(self): self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.curr_seq_lens: List[int] = [] @@ -263,9 +268,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 - self.input_builder = input_builder - self.runner = input_builder.runner - def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", chunked_prefill_enabled: bool): diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 7cd2049f0c0a5..8722d7376795a 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: self.chunked_prefill = input_builder.chunked_prefill - self.input_data = input_builder.input_data + self.input_builder = input_builder + + def prepare(self): + self.input_data = self.input_builder.input_data def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 56cc43430301f..3df7f54cbd8d2 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): _metadata_cls: Type[TAttentionMetadata] def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): self.slot_mapping: List[int] = [] self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] @@ -134,12 +141,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.num_prefill_tokens = 0 self.num_decode_tokens = 0 - self.input_builder = input_builder - self.runner = input_builder.runner - - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", chunked_prefill_enabled: bool): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index abbf6450ab7f6..4b429b67b36f8 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -144,9 +144,7 @@ def __init__(self, runner: "CPUModelRunner", finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() - self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.runner = runner - self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled or runner.cache_config.enable_prefix_caching) self.model_input_cls = self.runner._model_input_cls @@ -156,10 +154,17 @@ def __init__(self, self.device = self.runner.device self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper self.enable_lora = self.runner.lora_config is not None + if self.runner.attn_backend is not None: + # spec decode (e.g. Medusa) does not have atten backend + attn_backend = self.runner.attn_backend + self.att_metadata_builder = attn_backend.get_builder_cls()(self) + + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.input_data = ModelInputForCPUBuilder.ModelInputData( self.runner.model_config.uses_mrope) - self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( - self) + self.att_metadata_builder.prepare() def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) @@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): """ _model_input_cls: Type[TModelInputForCPU] _builder_cls: Type[ModelInputForCPUBuilder] + builder: ModelInputForCPUBuilder def __init__( self, @@ -477,6 +483,10 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + if hasattr(self, "_builder_cls"): + # multi-step model runner does not have `_builder_cls` + self.builder = self._builder_cls(weakref.proxy(self)) + def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) @@ -522,10 +532,10 @@ def _prepare_model_input_tensors( metadata for possible additional steps, e.g., sampling. """ - builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) - builder.set_seq_group_list(seq_group_metadata_list) + self.builder.prepare(finished_requests_ids) + self.builder.set_seq_group_list(seq_group_metadata_list) - return builder.build() # type: ignore + return self.builder.build() # type: ignore # sampler property will be used by spec_decode_worker @property diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cb2ff0c934da3..e311c14111d49 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -457,17 +457,13 @@ def __init__(self, self.enable_prompt_adapter = (self.runner.prompt_adapter_config is not None) self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper - self.finished_requests_ids = finished_requests_ids self.decode_only = True - # Intermediate data (data in CPU before going to GPU) for - # the current sequence group. - self.inter_data_list: List[ - ModelInputForGPUBuilder.InterDataForSeqGroup] = [] - # Attention metadata inputs. - self.attn_metadata_builder = self.attn_backend.make_metadata_builder( - weakref.proxy(self)) + if self.attn_backend is not None: + # spec decode (e.g. Medusa) does not have atten backend + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self)) # Engine/Model configurations. self.chunked_prefill_enabled = ( @@ -479,6 +475,17 @@ def __init__(self, self.block_aligned_sliding_window = \ self.sliding_window_blocks * self.block_size + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.finished_requests_ids = finished_requests_ids + + # Intermediate data (data in CPU before going to GPU) for + # the current sequence group. + self.inter_data_list: List[ + ModelInputForGPUBuilder.InterDataForSeqGroup] = [] + + self.attn_metadata_builder.prepare() + def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, seq_group_metadata: SequenceGroupMetadata): """Compute context length, sequence length and tokens @@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ _model_input_cls: Type[TModelInputForGPU] _builder_cls: Type[ModelInputForGPUBuilder] + builder: ModelInputForGPUBuilder def __init__( self, @@ -1093,6 +1101,10 @@ def __init__( SamplingMetadataCache() \ if self.parallel_config.pipeline_parallel_size == 1 else None + if hasattr(self, "_builder_cls"): + # multi-step model runner does not have `_builder_cls` + self.builder = self._builder_cls(weakref.proxy(self)) + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: @@ -1226,13 +1238,13 @@ def _prepare_model_input_tensors( If cuda graph is required, this API automatically pads inputs. """ - builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + self.builder.prepare(finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: - builder.add_seq_group(seq_group_metadata) + self.builder.add_seq_group(seq_group_metadata) - builder.reset_cached_inter_data() + self.builder.reset_cached_inter_data() - return builder.build() # type: ignore + return self.builder.build() # type: ignore @contextmanager def set_in_profile_run(self): diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index acfd6d0b03f62..aef4bdcdd4bf9 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]): """A builder to create ModelRunnerInputBase objects. """ + @abstractmethod + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + raise NotImplementedError + @abstractmethod def add_seq_group(self, seq_group_metadata): """TBA""" diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 25a2fea1e8eac..053658d047311 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -113,7 +113,6 @@ def __init__(self, runner: "XPUModelRunner", finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() - self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.runner = runner self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend @@ -121,6 +120,10 @@ def __init__(self, self.block_size = self.runner.block_size self.device = self.runner.device + def prepare(self, + finished_requests_ids: Optional[List[str]] = None) -> None: + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) @@ -408,6 +411,8 @@ def __init__( SamplingMetadataCache() \ if self.parallel_config.pipeline_parallel_size == 1 else None + self.builder = self._builder_cls(weakref.proxy(self)) + def load_model(self) -> None: with DeviceMemoryProfiler() as m: self.model = get_model(vllm_config=self.vllm_config) @@ -517,7 +522,8 @@ def _prepare_model_input_tensors( metadata for possible additional steps, e.g., sampling. """ - builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + builder = self.builder + builder.prepare(finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: builder.add_seq_group(seq_group_metadata)