[core] separate builder init and builder prepare for each batch (#12253)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 22, 2025
1 parent 222a9dc commit 66818e5
Showing 10 changed files with 90 additions and 47 deletions.
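The pattern is the same across every file below: `__init__` now records only long-lived configuration (runner handle, sliding window, block size), while the new `prepare()` resets the per-batch accumulators (slot mappings, sequence lengths, token counters) before each batch, so each model runner can construct its builder once and reuse it. Here is a minimal sketch of that lifecycle, using simplified hypothetical names (`Runner`, `Builder`) rather than vLLM's real classes:

```python
# Lifecycle sketch with simplified, hypothetical names (Runner, Builder);
# vLLM's real runner and builder classes carry far more state.
from typing import List


class Builder:

    def __init__(self, runner: "Runner") -> None:
        # One-time setup: record configuration that never changes
        # between batches.
        self.runner = runner
        self.block_size = runner.block_size

    def prepare(self) -> None:
        # Per-batch setup: reset every accumulator.
        self.slot_mapping: List[int] = []
        self.seq_lens: List[int] = []


class Runner:
    block_size = 16

    def __init__(self) -> None:
        # After this commit: the builder is constructed once, here.
        # Before, a fresh builder was constructed on every batch.
        self.builder = Builder(self)

    def prepare_model_input(self, seq_lens: List[int]) -> None:
        self.builder.prepare()          # reset per-batch state
        self.builder.seq_lens.extend(seq_lens)


runner = Runner()
runner.prepare_model_input([8, 3])      # batch 1
runner.prepare_model_input([5])         # batch 2: fresh lists, same builder
```

Constructing the builder once moves object-construction work off the per-batch hot path and leaves only the cheap list resets there, which is presumably the motivation for the split.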
11 changes: 6 additions & 5 deletions vllm/attention/backends/abstract.py
@@ -65,11 +65,6 @@ def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
         raise NotImplementedError
 
-    @classmethod
-    def make_metadata_builder(cls, *args,
-                              **kwargs) -> "AttentionMetadataBuilder":
-        return cls.get_builder_cls()(*args, **kwargs)
-
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
@@ -214,6 +209,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
 
     @abstractmethod
     def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+        """Create the builder, remember some configuration and parameters."""
         raise NotImplementedError
 
+    @abstractmethod
+    def prepare(self) -> None:
+        """Prepare for one batch."""
+        raise NotImplementedError
+
     @abstractmethod
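The abstract base class now spells out the two-phase contract. Note that `@abstractmethod` on `__init__` is deliberate: a subclass that fails to define its own constructor cannot be instantiated. A sketch of the contract with an illustrative subclass (`DummyBuilder` is hypothetical, not one of vLLM's backends):

```python
# Sketch of the two-phase contract the abstract class now encodes.
from abc import ABC, abstractmethod
from typing import Any, List


class BuilderContract(ABC):

    @abstractmethod
    def __init__(self, input_builder: Any) -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError


class DummyBuilder(BuilderContract):

    def __init__(self, input_builder: Any) -> None:
        self.input_builder = input_builder  # one-time configuration

    def prepare(self) -> None:
        self.slot_mapping: List[int] = []   # per-batch state, reset each call


builder = DummyBuilder(input_builder=None)
builder.prepare()       # once per batch, before adding any sequence groups
```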
11 changes: 6 additions & 5 deletions vllm/attention/backends/flash_attn.py
@@ -375,6 +375,12 @@ class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -388,11 +394,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_decode_tokens = 0
         self.has_prefix_cache_hit = False
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool, prefix_cache_hit: bool):
14 changes: 8 additions & 6 deletions vllm/attention/backends/flashinfer.py
@@ -488,6 +488,14 @@ def advance_step(self,
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -500,12 +508,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
         # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
         # for the precise definition of the following fields.
         # An example:
8 changes: 5 additions & 3 deletions vllm/attention/backends/placeholder_attn.py
@@ -253,6 +253,11 @@ class PlaceholderAttentionMetadataBuilder(
         AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+    def prepare(self):
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
@@ -263,9 +268,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
5 changes: 4 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -282,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
     def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
         self.chunked_prefill = input_builder.chunked_prefill
-        self.input_data = input_builder.input_data
+        self.input_builder = input_builder
+
+    def prepare(self):
+        self.input_data = self.input_builder.input_data
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
13 changes: 7 additions & 6 deletions vllm/attention/backends/utils.py
@@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
     _metadata_cls: Type[TAttentionMetadata]
 
     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+
+    def prepare(self):
         self.slot_mapping: List[int] = []
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
@@ -134,12 +141,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
-        self.input_builder = input_builder
-        self.runner = input_builder.runner
-
-        self.sliding_window = input_builder.sliding_window
-        self.block_size = input_builder.block_size
-
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
             chunked_prefill_enabled: bool):
24 changes: 17 additions & 7 deletions vllm/worker/cpu_model_runner.py
@@ -144,9 +144,7 @@ def __init__(self,
                  runner: "CPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
-
         self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                 or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
@@ -156,10 +154,17 @@ def __init__(self,
         self.device = self.runner.device
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
         self.enable_lora = self.runner.lora_config is not None
+        if self.runner.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have atten backend
+            attn_backend = self.runner.attn_backend
+            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
+
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.input_data = ModelInputForCPUBuilder.ModelInputData(
             self.runner.model_config.uses_mrope)
-        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
-            self)
+        self.att_metadata_builder.prepare()
 
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
@@ -431,6 +436,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
     """
     _model_input_cls: Type[TModelInputForCPU]
     _builder_cls: Type[ModelInputForCPUBuilder]
+    builder: ModelInputForCPUBuilder
 
     def __init__(
         self,
@@ -477,6 +483,10 @@ def __init__(
         # Set after load_model.
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
@@ -522,10 +532,10 @@ def _prepare_model_input_tensors(
         metadata for possible additional steps, e.g., sampling.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
-        builder.set_seq_group_list(seq_group_metadata_list)
+        self.builder.prepare(finished_requests_ids)
+        self.builder.set_seq_group_list(seq_group_metadata_list)
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     # sampler property will be used by spec_decode_worker
     @property
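Two details of the runner-side wiring are worth noting. The builder receives only a `weakref.proxy` of the runner, so the runner-owns-builder, builder-points-at-runner pair does not form a strong reference cycle; and the `hasattr` guard exists because, per the inline comment, the multi-step model runner defines no `_builder_cls`. A sketch of that ownership pattern with simplified stand-in classes:

```python
# Ownership sketch: the runner owns the builder; the builder holds only
# a weak proxy back to the runner, so the pair forms no strong reference
# cycle. Runner/Builder are simplified stand-ins for the vLLM classes.
import weakref
from typing import List


class Builder:

    def __init__(self, runner) -> None:
        self.runner = runner            # a weakref.proxy, not a strong ref

    def prepare(self) -> None:
        self.seq_groups: List[str] = []


class Runner:
    # Subclasses may omit this attribute (e.g. vLLM's multi-step runner
    # defines no _builder_cls), hence the hasattr guard below.
    _builder_cls = Builder

    def __init__(self) -> None:
        if hasattr(self, "_builder_cls"):
            self.builder = self._builder_cls(weakref.proxy(self))


runner = Runner()
runner.builder.prepare()
assert isinstance(runner.builder.runner, weakref.ProxyTypes)
```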
36 changes: 24 additions & 12 deletions vllm/worker/model_runner.py
@@ -457,17 +457,13 @@ def __init__(self,
         self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                       is not None)
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
-        self.finished_requests_ids = finished_requests_ids
         self.decode_only = True
 
-        # Intermediate data (data in CPU before going to GPU) for
-        # the current sequence group.
-        self.inter_data_list: List[
-            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
-
-        # Attention metadata inputs.
-        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
-            weakref.proxy(self))
+        if self.attn_backend is not None:
+            # spec decode (e.g. Medusa) does not have atten backend
+            self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+                weakref.proxy(self))
 
         # Engine/Model configurations.
         self.chunked_prefill_enabled = (
@@ -479,6 +475,17 @@ def __init__(self,
             self.block_aligned_sliding_window = \
                 self.sliding_window_blocks * self.block_size
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.finished_requests_ids = finished_requests_ids
+
+        # Intermediate data (data in CPU before going to GPU) for
+        # the current sequence group.
+        self.inter_data_list: List[
+            ModelInputForGPUBuilder.InterDataForSeqGroup] = []
+
+        self.attn_metadata_builder.prepare()
+
     def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                       seq_group_metadata: SequenceGroupMetadata):
         """Compute context length, sequence length and tokens
@@ -993,6 +1000,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     """
     _model_input_cls: Type[TModelInputForGPU]
     _builder_cls: Type[ModelInputForGPUBuilder]
+    builder: ModelInputForGPUBuilder
 
     def __init__(
         self,
@@ -1093,6 +1101,10 @@ def __init__(
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        if hasattr(self, "_builder_cls"):
+            # multi-step model runner does not have `_builder_cls`
+            self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
@@ -1226,13 +1238,13 @@ def _prepare_model_input_tensors(
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        self.builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
-            builder.add_seq_group(seq_group_metadata)
+            self.builder.add_seq_group(seq_group_metadata)
 
-        builder.reset_cached_inter_data()
+        self.builder.reset_cached_inter_data()
 
-        return builder.build()  # type: ignore
+        return self.builder.build()  # type: ignore
 
     @contextmanager
     def set_in_profile_run(self):
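With the builder now long-lived, `_prepare_model_input_tensors` must call `prepare()` before adding sequence groups; a stale builder would otherwise carry the previous batch's state into the next `build()`. A minimal hypothetical builder showing the failure mode the call guards against:

```python
# The bug prepare() guards against: a long-lived builder that is not
# reset between batches leaks the previous batch's sequence groups into
# the next build(). Minimal hypothetical builder, mirroring the calling
# sequence in _prepare_model_input_tensors above.
from typing import List, Optional


class Builder:

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.finished_requests_ids = finished_requests_ids
        self.inter_data_list: List[str] = []

    def add_seq_group(self, seq_group: str) -> None:
        self.inter_data_list.append(seq_group)

    def build(self) -> List[str]:
        return list(self.inter_data_list)


builder = Builder()

builder.prepare()
builder.add_seq_group("req-0")
assert builder.build() == ["req-0"]

builder.prepare()               # skip this and "req-0" would leak through
builder.add_seq_group("req-1")
assert builder.build() == ["req-1"]
```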
5 changes: 5 additions & 0 deletions vllm/worker/model_runner_base.py
@@ -200,6 +200,11 @@ class ModelRunnerInputBuilderBase(ABC, Generic[T]):
     """A builder to create ModelRunnerInputBase objects.
     """
 
+    @abstractmethod
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def add_seq_group(self, seq_group_metadata):
         """TBA"""
10 changes: 8 additions & 2 deletions vllm/worker/xpu_model_runner.py
@@ -113,14 +113,17 @@ def __init__(self,
                  runner: "XPUModelRunner",
                  finished_requests_ids: Optional[List[str]] = None) -> None:
         super().__init__()
-        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
         self.model_input_cls = self.runner._model_input_cls
         self.attn_backend = self.runner.attn_backend
         self.sliding_window = self.runner.sliding_window
         self.block_size = self.runner.block_size
         self.device = self.runner.device
 
+    def prepare(self,
+                finished_requests_ids: Optional[List[str]] = None) -> None:
+        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         self.seq_group_metadata_list.append(seq_group_metadata)
 
@@ -408,6 +411,8 @@ def __init__(
             SamplingMetadataCache() \
             if self.parallel_config.pipeline_parallel_size == 1 else None
 
+        self.builder = self._builder_cls(weakref.proxy(self))
+
     def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:
             self.model = get_model(vllm_config=self.vllm_config)
@@ -517,7 +522,8 @@ def _prepare_model_input_tensors(
         metadata for possible additional steps, e.g., sampling.
         """
-        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        builder = self.builder
+        builder.prepare(finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)
