From d559979c548c4bee6eca089d5e6dc318630bf465 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 22 Nov 2024 17:34:03 -0800 Subject: [PATCH] [bugfix] fix cpu tests (#10585) Signed-off-by: youkaichao --- vllm/worker/cpu_embedding_model_runner.py | 4 +++- vllm/worker/cpu_enc_dec_model_runner.py | 4 +++- vllm/worker/cpu_model_runner.py | 18 ++++++++++-------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_embedding_model_runner.py index d0b8fec48d74f..978de73df6b70 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_embedding_model_runner.py @@ -3,6 +3,7 @@ import torch +from vllm.forward_context import set_forward_context from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams @@ -64,7 +65,8 @@ def execute_model( intermediate_tensors, } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index d040831870bd8..1f8e2d2d88a23 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -4,6 +4,7 @@ import torch from vllm.attention import AttentionMetadata +from vllm.forward_context import set_forward_context from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs @@ -303,7 +304,8 @@ def execute_model( intermediate_tensors, } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 66bd844c94901..2cf573625401a 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -10,6 +10,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -487,14 +488,15 @@ def execute_model( multimodal_kwargs = MultiModalKwargs.as_kwargs( model_input.multi_modal_kwargs, device=self.device) - hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **multimodal_kwargs, - ) + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **multimodal_kwargs, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states,