From 9cc018ae38553ede15bd3646679916d05147d599 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Wed, 27 Nov 2024 04:26:27 -0800 Subject: [PATCH] [V1] Update interface for mistral-format Pixtral (#10703) Signed-off-by: Roger Wang Signed-off-by: Andrew Feldman --- vllm/model_executor/models/pixtral.py | 47 ++++++++++++++++----------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 6711cbf5694b9..45171c1a04b17 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -31,7 +31,7 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import PlaceholderRange +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, resolve_visual_encoder_outputs) @@ -190,6 +190,25 @@ def sampler(self): return get_sampler() + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.vision_args.image_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -197,31 +216,21 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for pixtral. - - TODO - """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.vision_args.image_token_id) - - input_ids = None - else: - inputs_embeds = None + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions,