Skip to content

Commit

Permalink
Properly initializing the new field in the attn metadata (#337)
Browse files Browse the repository at this point in the history
Signed-off-by: Gregory Shtrasberg <[email protected]>
  • Loading branch information
gshtras committed Jan 7, 2025
1 parent 64668c6 commit 3eaca59
Show file tree
Hide file tree
Showing 17 changed files with 38 additions and 9 deletions.
2 changes: 2 additions & 0 deletions tests/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,7 @@ def make_test_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
Expand Down Expand Up @@ -958,6 +959,7 @@ def make_test_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
Expand Down
3 changes: 3 additions & 0 deletions tests/worker/test_model_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
Expand Down Expand Up @@ -126,6 +127,7 @@ def test_embedding_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
Expand Down Expand Up @@ -177,6 +179,7 @@ def test_multi_step_model_runner_input():
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
Expand Down
5 changes: 2 additions & 3 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)

Expand Down Expand Up @@ -126,8 +126,7 @@ class AttentionMetadata:

# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool = field(init=False,
default_factory=lambda: True)
enable_kv_scales_calculation: bool

@property
@abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def prefill_metadata(
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
Expand Down Expand Up @@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
Expand Down
3 changes: 3 additions & 0 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
Expand Down Expand Up @@ -268,6 +269,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
Expand Down Expand Up @@ -550,6 +552,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
Expand Down
2 changes: 2 additions & 0 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def graph_capture_get_metadata_for_batch(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
Expand Down Expand Up @@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
Expand Down
3 changes: 3 additions & 0 deletions vllm/attention/backends/placeholder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_decode_query_len=0,
Expand Down Expand Up @@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_decode_query_len=self.max_decode_query_len,
Expand Down Expand Up @@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_prefills=self.num_prefills,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
Expand Down
2 changes: 2 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
Expand Down Expand Up @@ -181,6 +182,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)

return attn_metadata
Expand Down
2 changes: 2 additions & 0 deletions vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
Expand Down Expand Up @@ -316,6 +317,7 @@ def graph_capture_get_metadata_for_batch(
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
Expand Down
2 changes: 2 additions & 0 deletions vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
Expand Down Expand Up @@ -259,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,8 @@ def _prepare_prompt(
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=
None # FIXME(kzawora): mutli-modality will not work here
None, # FIXME(kzawora): mutli-modality will not work here
enable_kv_scales_calculation=False,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

Expand Down Expand Up @@ -1046,7 +1047,9 @@ def _prepare_decode(
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None)
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
Expand Down
2 changes: 0 additions & 2 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,6 @@ def from_broadcasted_tensor_dict(
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
if "enable_kv_scales_calculation" in tensor_dict:
tensor_dict.pop("enable_kv_scales_calculation")
return cls(**tensor_dict)


Expand Down
3 changes: 1 addition & 2 deletions vllm/worker/model_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def _init_attn_metadata_from_tensor_dict(
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
if field.name in tensor_dict and field.name != \
'enable_kv_scales_calculation':
if field.name in tensor_dict:
valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)

attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
Expand Down
1 change: 1 addition & 0 deletions vllm/worker/openvino_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def _prepare_model_input(
block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)

multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
Expand Down
5 changes: 5 additions & 0 deletions vllm/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def _dummy_run(
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=None,
context_lens=None,
effective_query_lens=None,
Expand All @@ -204,6 +205,7 @@ def _dummy_run(
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
Expand Down Expand Up @@ -235,6 +237,7 @@ def _dummy_run(
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
Expand Down Expand Up @@ -420,6 +423,7 @@ def _prepare_prompt(
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
Expand Down Expand Up @@ -491,6 +495,7 @@ def _prepare_decode(
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
Expand Down
2 changes: 2 additions & 0 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def _prepare_prompt(
is_prompt=True,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=seqlen_q,
max_seqlen=max_seqlen,
Expand Down Expand Up @@ -341,6 +342,7 @@ def _prepare_decode(
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=torch.tensor([]),
max_seqlen=0,
Expand Down

0 comments on commit 3eaca59

Please sign in to comment.