diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 848eea7f54cab..8011398551b9d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -909,6 +909,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -958,6 +959,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 309854e6babf3..57f1fd47a600f 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -74,6 +74,7 @@ def test_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -126,6 +127,7 @@ def test_embedding_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), @@ -177,6 +179,7 @@ def test_multi_step_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 4e58e6db91a4c..796017c1c2b21 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, fields from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) @@ -126,8 +126,7 @@ class AttentionMetadata: # Enable/disable KV scales calculation. This is so that we can disable the # calculation until after prefill and cuda graph capture. - enable_kv_scales_calculation: bool = field(init=False, - default_factory=lambda: True) + enable_kv_scales_calculation: bool @property @abstractmethod diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 0bef046e16806..aa7cc981372a0 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -222,6 +222,7 @@ def prefill_metadata( slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c2af8f5216e08..c7850009a954c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -224,6 +224,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -268,6 +269,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, @@ -550,6 +552,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index f1d1cd3303734..d9d07e6be173e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -218,6 +218,7 @@ def graph_capture_get_metadata_for_batch( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -711,6 +712,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 534f79b3a60bf..d2dc0d6cf0a5f 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -378,6 +380,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e9ef8e52aa2fd..bbfcad0a3d626 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -152,6 +152,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -181,6 +182,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index cb2f4d5adeaaf..6f4b2ffb2371c 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -372,6 +372,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], prefill_block_tables=prefill_block_tables, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) return attn_metadata diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 56cc43430301f..8ceeaf48bb19d 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -264,6 +264,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -316,6 +317,7 @@ def graph_capture_get_metadata_for_batch( num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 29d4571e05838..88b3f56e0848e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -215,6 +215,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -259,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 9d479f412af46..9d49c6d0a0983 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -892,7 +892,8 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps= - None # FIXME(kzawora): mutli-modality will not work here + None, # FIXME(kzawora): mutli-modality will not work here + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -1046,7 +1047,9 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + ) return PrepareDecodeMetadata(input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6fb3ca1936b39..4155f3f275083 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -176,8 +176,6 @@ def from_broadcasted_tensor_dict( if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) - if "enable_kv_scales_calculation" in tensor_dict: - tensor_dict.pop("enable_kv_scales_calculation") return cls(**tensor_dict) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 20b91ead08285..c7abad7e0258d 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -46,8 +46,7 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - if field.name in tensor_dict and field.name != \ - 'enable_kv_scales_calculation': + if field.name in tensor_dict: valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 6000e5dfe4e30..23a48e29731a7 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -278,6 +278,7 @@ def _prepare_model_input( block_indices_begins=block_indices_begins_tensor, max_context_len=max_context_len_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7bdb7f0e2d6a9..4a9c06b1d4763 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -186,6 +186,7 @@ def _dummy_run( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=None, context_lens=None, effective_query_lens=None, @@ -204,6 +205,7 @@ def _dummy_run( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=effective_query_lens, @@ -235,6 +237,7 @@ def _dummy_run( num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) @@ -420,6 +423,7 @@ def _prepare_prompt( num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, effective_query_lens=prompt_lens, @@ -491,6 +495,7 @@ def _prepare_decode( num_decode_tokens=batch_size, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, block_tables=block_tables, context_lens=context_lens, ) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 9cf25387560da..a4a03050a46c3 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -257,6 +257,7 @@ def _prepare_prompt( is_prompt=True, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, @@ -341,6 +342,7 @@ def _prepare_decode( is_prompt=False, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0,