From 5108119b191af2a2d81a73fdfd7d706a4177152b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 5 Nov 2024 17:31:18 +0000 Subject: [PATCH 01/16] Initial prototype for multi-modal processor Signed-off-by: DarkLight1337 --- .../dev/multimodal/multimodal_index.rst | 2 +- .../models/enabling_multimodal_inputs.rst | 2 +- .../mm_processor_kwargs/test_qwen.py | 4 +- tests/multimodal/test_base.py | 22 +- vllm/config.py | 2 +- vllm/engine/async_llm_engine.py | 4 + vllm/engine/llm_engine.py | 5 +- vllm/engine/multiprocessing/client.py | 6 + vllm/engine/protocol.py | 16 +- vllm/entrypoints/openai/serving_chat.py | 1 - vllm/entrypoints/openai/serving_completion.py | 1 - vllm/inputs/__init__.py | 4 +- vllm/inputs/data.py | 13 +- vllm/inputs/preprocess.py | 130 +++++++-- vllm/inputs/registry.py | 17 +- vllm/model_executor/models/chatglm.py | 4 +- vllm/model_executor/models/fuyu.py | 4 +- vllm/model_executor/models/h2ovl.py | 10 +- vllm/model_executor/models/internvl.py | 6 +- vllm/model_executor/models/minicpmv.py | 4 +- vllm/model_executor/models/mllama.py | 2 +- vllm/model_executor/models/molmo.py | 4 +- vllm/model_executor/models/pixtral.py | 10 +- vllm/model_executor/models/qwen.py | 12 +- vllm/model_executor/models/qwen2_audio.py | 8 +- vllm/model_executor/models/qwen2_vl.py | 8 +- vllm/model_executor/models/ultravox.py | 8 +- vllm/multimodal/__init__.py | 10 +- vllm/multimodal/audio.py | 4 +- vllm/multimodal/base.py | 187 +----------- vllm/multimodal/image.py | 10 +- vllm/multimodal/inputs.py | 229 +++++++++++++++ vllm/multimodal/processing.py | 275 ++++++++++++++++++ vllm/multimodal/registry.py | 86 +++++- vllm/multimodal/video.py | 6 +- vllm/sequence.py | 37 ++- vllm/spec_decode/draft_model_runner.py | 4 +- vllm/v1/engine/llm_engine.py | 15 +- vllm/worker/cpu_enc_dec_model_runner.py | 4 +- vllm/worker/cpu_model_runner.py | 16 +- vllm/worker/embedding_model_runner.py | 4 +- vllm/worker/enc_dec_model_runner.py | 4 +- vllm/worker/model_runner.py | 27 +- vllm/worker/neuron_model_runner.py | 24 +- vllm/worker/openvino_model_runner.py | 23 +- vllm/worker/xpu_model_runner.py | 20 +- 46 files changed, 943 insertions(+), 351 deletions(-) create mode 100644 vllm/multimodal/inputs.py create mode 100644 vllm/multimodal/processing.py diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index e112b43aade5e..30f543abc20c7 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -53,7 +53,7 @@ Base Classes .. autodata:: vllm.multimodal.MultiModalDataDict -.. autoclass:: vllm.multimodal.MultiModalInputs +.. autoclass:: vllm.multimodal.MultiModalKwargs :members: :show-inheritance: diff --git a/docs/source/models/enabling_multimodal_inputs.rst b/docs/source/models/enabling_multimodal_inputs.rst index 3d0d1aec69845..49b5285c45590 100644 --- a/docs/source/models/enabling_multimodal_inputs.rst +++ b/docs/source/models/enabling_multimodal_inputs.rst @@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i 3. Register maximum number of multi-modal tokens ------------------------------------------------ -For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance +For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item and register it via :meth:`INPUT_REGISTRY.register_dummy_data `. .. code-block:: diff diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py index a01651b171d60..9f2daa4c7273f 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py @@ -6,7 +6,7 @@ from PIL.Image import Image from vllm.inputs import InputContext, token_inputs -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from .....conftest import IMAGE_ASSETS @@ -96,7 +96,7 @@ def test_input_mapper_valid_mm_data(input_mapper_for_qwen, mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data) # Ensure that we get the appropriately shaped pixel_values # for images and image embeddings, respectively. - assert isinstance(mapped_img_data, MultiModalInputs) + assert isinstance(mapped_img_data, MultiModalKwargs) assert "pixel_values" in mapped_img_data assert mapped_img_data["pixel_values"].shape == expected_shape diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index 68d05de904ba8..bfaf2cdeaa8d4 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -1,6 +1,6 @@ import torch -from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal.base import MultiModalKwargs, NestedTensors def assert_nested_tensors_equal(expected: NestedTensors, @@ -13,8 +13,8 @@ def assert_nested_tensors_equal(expected: NestedTensors, assert_nested_tensors_equal(expected_item, actual_item) -def assert_multimodal_inputs_equal(expected: MultiModalInputs, - actual: MultiModalInputs): +def assert_multimodal_inputs_equal(expected: MultiModalKwargs, + actual: MultiModalKwargs): assert set(expected.keys()) == set(actual.keys()) for key in expected: assert_nested_tensors_equal(expected[key], actual[key]) @@ -22,7 +22,7 @@ def assert_multimodal_inputs_equal(expected: MultiModalInputs, def test_multimodal_input_batch_single_tensor(): t = torch.rand([1, 2]) - result = MultiModalInputs.batch([{"image": t}]) + result = MultiModalKwargs.batch([{"image": t}]) assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)}) @@ -30,7 +30,7 @@ def test_multimodal_input_batch_multiple_tensors(): a = torch.rand([1, 1, 2]) b = torch.rand([1, 1, 2]) c = torch.rand([1, 1, 2]) - result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}]) assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])}) @@ -38,7 +38,7 @@ def test_multimodal_input_batch_multiple_heterogeneous_tensors(): a = torch.rand([1, 2, 2]) b = torch.rand([1, 3, 2]) c = torch.rand([1, 4, 2]) - result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + result = MultiModalKwargs.batch([{"image": a}, {"image": b}, {"image": c}]) assert_multimodal_inputs_equal(result, {"image": [a, b, c]}) @@ -46,7 +46,7 @@ def test_multimodal_input_batch_nested_tensors(): a = torch.rand([2, 3]) b = torch.rand([2, 3]) c = torch.rand([2, 3]) - result = MultiModalInputs.batch([{ + result = MultiModalKwargs.batch([{ "image": [a] }, { "image": [b] @@ -65,7 +65,7 @@ def test_multimodal_input_batch_heterogeneous_lists(): a = torch.rand([1, 2, 3]) b = torch.rand([1, 2, 3]) c = torch.rand([1, 2, 3]) - result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) assert_multimodal_inputs_equal( result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]}) @@ -76,7 +76,7 @@ def test_multimodal_input_batch_multiple_batchable_lists(): b = torch.rand([1, 2, 3]) c = torch.rand([1, 2, 3]) d = torch.rand([1, 2, 3]) - result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}]) + result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}]) assert_multimodal_inputs_equal( result, {"image": torch.stack([torch.stack([a, b]), @@ -88,8 +88,8 @@ def test_multimodal_input_batch_mixed_stacking_depths(): b = torch.rand([1, 3, 3]) c = torch.rand([1, 4, 3]) - result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]}) - result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}]) + result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b, c]}]) assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]}) diff --git a/vllm/config.py b/vllm/config.py index 814e00c8785f0..86127e23e2d3c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -107,7 +107,7 @@ class ModelConfig: matches the model name exposed via the APIs. If multiple model names provided, the first name will be used. If not specified, the model name will be the same as `model`. - limit_mm_per_prompt: Maximum number of data instances per modality + limit_mm_per_prompt: Maximum number of data items per modality per prompt. Only applicable for multimodal models. override_neuron_config: Initialize non default neuron config or override default neuron config that are specific to Neuron devices, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b0fdc67776bbd..d0ad253489827 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -19,6 +19,7 @@ from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( @@ -721,6 +722,9 @@ def _error_callback(self, exc: Exception) -> None: self.set_errored(exc) self._request_tracker.propagate_exception(exc) + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.engine.input_preprocessor + async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a1809b1a9dd26..3cf4e582a245a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -39,6 +39,7 @@ from vllm.model_executor.guided_decoding import ( get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -226,6 +227,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -338,7 +340,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: model_config) self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) + self.tokenizer, + mm_registry) self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 882742c2fc61b..fe21c58c775fe 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -31,6 +31,7 @@ # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -94,6 +95,8 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, parallel_config=engine_config.parallel_config, enable_lora=bool(engine_config.lora_config), ) + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer) # Send RPCGenerateRequest to the MQLLMEngine. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) @@ -345,6 +348,9 @@ async def _check_success(error_message: str, socket: Socket): or response != VLLM_RPC_SUCCESS_STR): raise ValueError(error_message) + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.input_preprocessor + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): return await self.tokenizer.get_lora_tokenizer_async(lora_request) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e0b59d94cfdc3..e15395d75c91f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -62,7 +62,6 @@ def generate( async def beam_search( self, prompt: PromptType, - model_config: ModelConfig, request_id: str, params: BeamSearchParams, ) -> AsyncGenerator[RequestOutput, None]: @@ -74,13 +73,14 @@ async def beam_search( length_penalty = params.length_penalty include_stop_str_in_output = params.include_stop_str_in_output - tokenizer = await self.get_tokenizer() - input_preprocessor = InputPreprocessor(model_config, tokenizer) + preprocessor = await self.get_input_preprocessor() + tokenizer_group = preprocessor.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async() if is_explicit_encoder_decoder_prompt(prompt): raise NotImplementedError else: - processed_inputs = input_preprocessor._prompt_to_llm_inputs( + processed_inputs = preprocessor._prompt_to_llm_inputs( prompt, request_id=request_id, ) @@ -220,6 +220,7 @@ async def abort(self, request_id: str) -> None: Args: request_id: The unique id of the request. """ + ... @abstractmethod async def get_model_config(self) -> ModelConfig: @@ -228,8 +229,13 @@ async def get_model_config(self) -> ModelConfig: @abstractmethod async def get_decoding_config(self) -> DecodingConfig: - ... """Get the decoding configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_input_preprocessor(self) -> InputPreprocessor: + """Get the input processor of the vLLM engine.""" + ... @abstractmethod async def get_tokenizer( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9551b4f2091dd..359578fd9f7e4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -187,7 +187,6 @@ async def create_chat_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, - model_config=self.model_config, request_id=request_id, params=sampling_params, ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 570232be38379..5a9da5e0c47b4 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -140,7 +140,6 @@ async def create_completion( if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, - model_config=self.model_config, request_id=request_id, params=sampling_params, ) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 68ac50a2c5a16..338589ed04dc4 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -3,7 +3,8 @@ SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import DummyData, InputContext, InputRegistry +from .registry import (DummyData, InputContext, InputProcessingContext, + InputRegistry) INPUT_REGISTRY = InputRegistry() """ @@ -32,6 +33,7 @@ "INPUT_REGISTRY", "DummyData", "InputContext", + "InputProcessingContext", "InputRegistry", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 46b41f431bec7..ac1a425538c34 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict + from vllm.multimodal.inputs import MultiModalInputsV2 class TextPrompt(TypedDict): @@ -36,13 +37,13 @@ class TokensPrompt(TypedDict): multi_modal_data: NotRequired["MultiModalDataDict"] """ - Optional multi-modal data to pass to the model, + DEPRECATED: Optional multi-modal data to pass to the model, if the model supports it. """ mm_processor_kwargs: NotRequired[Dict[str, Any]] """ - Optional multi-modal processor kwargs to be forwarded to the + DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities have registered mappers etc for the model being considered, we attempt to pass the mm_processor_kwargs to each of them. @@ -176,7 +177,7 @@ def token_inputs( return inputs -DecoderOnlyInputs = TokenInputs +DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. @@ -191,14 +192,14 @@ class EncoderDecoderInputs(TypedDict): This specifies the required data for encoder-decoder models. """ - encoder: TokenInputs + encoder: Union[TokenInputs, "MultiModalInputsV2"] """The inputs for the encoder portion.""" - decoder: TokenInputs + decoder: Union[TokenInputs, "MultiModalInputsV2"] """The inputs for the decoder portion.""" -SingletonInputs = TokenInputs +SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"] """ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a5c787a56b5a9..3fb7865d22a34 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,11 +1,13 @@ import asyncio -from typing import List, Optional +from typing import List, Mapping, Optional, Union from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once @@ -23,11 +25,13 @@ def __init__( self, model_config: ModelConfig, tokenizer: Optional[BaseTokenizerGroup], + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer + self.mm_registry = mm_registry def get_tokenizer_group(self) -> BaseTokenizerGroup: if self.tokenizer is None: @@ -198,14 +202,66 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) + def _can_process_multimodal(self) -> bool: + # Interim measure so we can handle models that have yet to be + # updated to use the new multi-modal processor + return self.mm_registry.has_processor(self.model_config) + + def _process_multimodal( + self, + prompt: Union[str, List[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + ) -> MultiModalInputsV2: + """ + Apply the model's multi-modal processor to a multi-modal prompt, + returning the corresponding token IDs and metadata. + """ + tokenizer_group = self.get_tokenizer_group() + tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) + + mm_processor = self.mm_registry.create_processor( + self.model_config, tokenizer) + + if isinstance(prompt, list): + prompt = tokenizer.decode(prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs) + + async def _process_multimodal_async( + self, + prompt: Union[str, List[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + lora_request: Optional[LoRARequest], + ) -> MultiModalInputsV2: + """Async version of :meth:`_process_multimodal`.""" + tokenizer_group = self.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request + ) + + mm_processor = self.mm_registry.create_processor( + self.model_config, tokenizer) + if isinstance(prompt, list): + logger.warning("Passing `multi_modal_data` in TokensPrompt is" + "deprecated and will be removed in a future update") + prompt = tokenizer.decode(prompt) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs) + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> SingletonInputs: - ''' - Extract the components of any single encoder or decoder input prompt. + """ + Extract the singleton inputs from a prompt. Arguments: @@ -215,12 +271,8 @@ def _prompt_to_llm_inputs( Returns: - * prompt - * prompt_token_ids - * multi_modal_data - * mm_processor_kwargs (request-level input processor/mapper overrides) - ''' - + * :class:`SingletonInputs` instance + """ parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": @@ -243,6 +295,14 @@ def _prompt_to_llm_inputs( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -253,13 +313,22 @@ def _prompt_to_llm_inputs( text_content = parsed["content"] prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + prompt_token_ids = self._tokenize_prompt( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -299,6 +368,14 @@ async def _prompt_to_llm_inputs_async( multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + if multi_modal_data is not None and self._can_process_multimodal(): + return await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + return token_inputs( prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data, @@ -309,13 +386,22 @@ async def _prompt_to_llm_inputs_async( text_content = parsed["content"] prompt_text = text_content["prompt"] + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return await self._process_multimodal_async( + prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + prompt_token_ids = await self._tokenize_prompt_async( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") return token_inputs( prompt=prompt_text, @@ -331,7 +417,8 @@ def _build_enc_dec_llm_inputs( encoder_inputs: SingletonInputs, decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if encoder_inputs["type"] == "token": + if (encoder_inputs["type"] == "token" + or encoder_inputs["type"] == "multimodal"): pass else: assert_never(encoder_inputs) @@ -340,7 +427,8 @@ def _build_enc_dec_llm_inputs( dec_token_ids = self._prepare_decoder_input_ids_for_generation( None) decoder_inputs = token_inputs(dec_token_ids) - elif decoder_inputs["type"] == "token": + elif (decoder_inputs["type"] == "token" + or decoder_inputs["type"] == "multimodal"): dec_token_ids = self._prepare_decoder_input_ids_for_generation( decoder_inputs["prompt_token_ids"]) decoder_inputs["prompt_token_ids"] = dec_token_ids @@ -361,7 +449,7 @@ def _process_encoder_decoder_prompt( prompt: PromptType, request_id: str, ) -> EncoderDecoderInputs: - ''' + """ For encoder/decoder models only: Process an input prompt into an :class:`EncoderDecoderInputs` instance. @@ -391,8 +479,7 @@ def _process_encoder_decoder_prompt( Returns: * :class:`EncoderDecoderInputs` instance - ''' - + """ encoder_inputs: SingletonInputs decoder_inputs: Optional[SingletonInputs] @@ -460,7 +547,8 @@ def _build_decoder_only_llm_inputs( prompt_inputs: DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - if prompt_inputs["type"] == "token": + if (prompt_inputs["type"] == "token" + or prompt_inputs["type"] == "multimodal"): prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, @@ -477,7 +565,7 @@ def _process_decoder_only_prompt( lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> DecoderOnlyInputs: - ''' + """ For decoder-only models: Process an input prompt into an :class:`DecoderOnlyInputs` instance. @@ -491,7 +579,7 @@ def _process_decoder_only_prompt( Returns: * :class:`DecoderOnlyInputs` instance - ''' + """ prompt_comps = self._prompt_to_llm_inputs( prompt, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 7d7a797be4f60..47edcfd8ed1c8 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -5,10 +5,12 @@ Optional, Protocol, Type, cast) from torch import nn -from transformers import PretrainedConfig +from transformers import PretrainedConfig, ProcessorMixin from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) @@ -61,6 +63,19 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: return self.model_config.hf_image_processor_config +@dataclass(frozen=True) +class InputProcessingContext(InputContext): + tokenizer: AnyTokenizer + """The tokenizer used to tokenize the inputs.""" + + def get_hf_processor(self) -> ProcessorMixin: + return cached_get_processor( + self.model_config.tokenizer, + tokenizer=self.tokenizer, # Override the tokenizer with ours + trust_remote_code=self.model_config.trust_remote_code, + ) + + N = TypeVar("N", bound=Type[nn.Module]) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index c3c9ec703c1e6..d0c5486ca5ac5 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.base import MultiModalData from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -74,7 +74,7 @@ def mm_input_mapper_for_glmv( raise pixel_values = raw_batch_data['images'] - return MultiModalInputs({'pixel_values': pixel_values}) + return MultiModalKwargs({'pixel_values': pixel_values}) def merge_glm_vision_embeddings( diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 0de590d1d8372..1975d8a570f9a 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -35,7 +35,7 @@ from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) @@ -219,7 +219,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): ]) # image has been processed with prompt in input processor - return MultiModalInputs({"pixel_values": data}) + return MultiModalKwargs({"pixel_values": data}) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu) diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 43242fe370ba2..767171dad7c7b 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -16,7 +16,7 @@ token_inputs) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.utils import is_list_of @@ -324,12 +324,12 @@ def input_mapper( data: object, *, max_dynamic_patch: Optional[int] = None, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: # NOTE: Preprocessing for the image data is done in the # 'input_processor' function during actual inference. if isinstance(data, dict): - return MultiModalInputs(data) + return MultiModalKwargs(data) # The section below is only used with dummy data during # memory profiling. @@ -347,7 +347,7 @@ def input_mapper( pixel_values = [image_pixel_values_mapper(img) for img in data] else: - return MultiModalInputs({"image_embeds": data}) + return MultiModalKwargs({"image_embeds": data}) model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -359,7 +359,7 @@ def input_mapper( return_tensors="pt", )[0] - return MultiModalInputs({ + return MultiModalKwargs({ "pixel_values": pixel_values, "image_token_id": image_token_id }) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d2ec0ff6e74c6..dfa869c1f5531 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -346,7 +346,7 @@ def input_mapper( # we can't stack here because images may have different num_patches data = [image_pixel_values_mapper(img) for img in data] else: - return MultiModalInputs({"image_embeds": data}) + return MultiModalKwargs({"image_embeds": data}) model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -355,7 +355,7 @@ def input_mapper( add_special_tokens=False, return_tensors="pt")[0] - return MultiModalInputs({ + return MultiModalKwargs({ "pixel_values": data, "image_token_id": image_token_id }) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f90df6b7df036..6ad04631f208e 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -53,7 +53,7 @@ from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData @@ -375,7 +375,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): batch_data["slice_start_id"] = data[0]["slice_start_id"] batch_data["slice_end_id"] = data[0]["slice_end_id"] - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 251bfc079684e..f7d2889a6fa5b 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1163,7 +1163,7 @@ def sample( def _parse_and_validate_image_input(self, **kwargs: object): # tensor with the same shape will be batched together by - # MultiModalInputs.batch, so pixel_values here can be: + # MultiModalKwargs.batch, so pixel_values here can be: # - List[List[torch.Tensor]]: # with shape (num_tiles, 3, image_res, image_res) # - List[torch.Tensor]: diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index ba798833e26a9..3aa68893de2d1 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) @@ -865,7 +865,7 @@ def image_input_mapper_for_molmo( ctx: InputContext, data: object, ): - return MultiModalInputs(data) + return MultiModalKwargs(data) def dummy_data_for_molmo(ctx: InputContext, seq_len: int, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 051454c49bff8..2e2683646ffdf 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -27,7 +27,7 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) from vllm.sequence import IntermediateTensors, SequenceData @@ -91,8 +91,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, def input_mapper_for_pixtral(ctx: InputContext, - data: object) -> MultiModalInputs: - """Maps the input data to its MultiModalInputs (if any). + data: object) -> MultiModalKwargs: + """Maps the input data to its MultiModalKwargs (if any). Args: ctx: Context of the loaded model. @@ -100,7 +100,7 @@ def input_mapper_for_pixtral(ctx: InputContext, to pixel_values in .forward() for a visual QWenLMHeadModel model. Returns: - MultiModalInputs containing the stacked normalized images tensor or + MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ # Early exit if we have provided an image to a language only Qwen model @@ -118,7 +118,7 @@ def input_mapper_for_pixtral(ctx: InputContext, dtype=torch.float16) images.append(image) - return MultiModalInputs({"images": images}) + return MultiModalKwargs({"images": images}) def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b2b5c70182135..a53ddf12c8f5d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -44,7 +44,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of @@ -723,8 +723,8 @@ def input_processor_for_qwen(ctx: InputContext, multi_modal_data=multi_modal_data) -def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: - """Maps the input data to its MultiModalInputs (if any). +def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalKwargs: + """Maps the input data to its MultiModalKwargs (if any). Args: ctx: Context of the loaded model. @@ -732,7 +732,7 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: to pixel_values in .forward() for a visual QWenLMHeadModel model. Returns: - MultiModalInputs containing the stacked normalized images tensor or + MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ # Early exit if we have provided an image to a language only Qwen model @@ -741,7 +741,7 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: logger.warning( "Images were provided but this model has no visual config; " "multimodal inputs will not be forwarded to the model.") - return MultiModalInputs() + return MultiModalKwargs() model_config = ctx.model_config tokenizer = cached_get_tokenizer( @@ -785,7 +785,7 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: data = [data] transformed_images = [transform(datum) for datum in data] pixel_values = torch.stack(transformed_images, dim=0) - return MultiModalInputs({"pixel_values": pixel_values}) + return MultiModalKwargs({"pixel_values": pixel_values}) def build_normalization_transform(image_size: int) -> transforms.Compose: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 6114548bda42c..c6edcf0071118 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -43,7 +43,7 @@ default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData @@ -222,13 +222,13 @@ def input_processor_for_qwen2_audio( def input_mapper_for_qwen2_audio( ctx: InputContext, multi_modal_data: Union[np.ndarray, List[np.ndarray]], -) -> MultiModalInputs: +) -> MultiModalKwargs: """Input mapper for Qwen2-Audio.""" if not isinstance(multi_modal_data, list): multi_modal_data = [multi_modal_data] if len(multi_modal_data) == 0: - return MultiModalInputs() + return MultiModalKwargs() processor = cached_get_processor(ctx.model_config.model) audio_feature_extractor = processor.feature_extractor @@ -255,7 +255,7 @@ def input_mapper_for_qwen2_audio( logger.error("Failed to process audio (%s)", multi_modal_data) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d801903f8f9fe..5ed57ce2f2f9d 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -58,7 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalInputs) + MultiModalKwargs) from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer @@ -567,10 +567,10 @@ def mm_input_mapper_for_qwen2_vl( *, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, -) -> MultiModalInputs: +) -> MultiModalKwargs: """Input mapper for Qwen2-VL.""" if data_type_key == "image" and isinstance(data, dict): - return MultiModalInputs({ + return MultiModalKwargs({ "image_embeds": data.get("image_embeds"), "image_grid_thw": data.get("image_grid_thw"), }) @@ -608,7 +608,7 @@ def mm_input_mapper_for_qwen2_vl( logger.error("Failed to process image (%s)", data) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 749750fc9c16e..13dfb21352529 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, @@ -116,11 +116,11 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): data = [data] if len(data) == 0: - return MultiModalInputs() + return MultiModalKwargs() # If the audio inputs are embeddings, no need for preprocessing if is_list_of(data, torch.Tensor, check="all"): - return MultiModalInputs({"audio_embeds": data}) + return MultiModalKwargs({"audio_embeds": data}) audio_features = [] for audio_input in data: @@ -154,7 +154,7 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): # Remove the batch dimension because we're wrapping it in a list. audio_features.append(single_audio_features.squeeze(0)) - return MultiModalInputs({"audio_features": audio_features}) + return MultiModalKwargs({"audio_features": audio_features}) def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 53da2badb9b98..ae60e0d60bc79 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,7 +1,7 @@ -from .base import (BatchedTensorInputs, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalInputs, - MultiModalPlaceholderDict, MultiModalPlaceholderMap, - MultiModalPlugin, NestedTensors) +from .base import MultiModalPlaceholderMap, MultiModalPlugin +from .inputs import (BatchedTensorInputs, MultiModalDataBuiltins, + MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -17,7 +17,7 @@ "BatchedTensorInputs", "MultiModalDataBuiltins", "MultiModalDataDict", - "MultiModalInputs", + "MultiModalKwargs", "MultiModalPlaceholderDict", "MultiModalPlaceholderMap", "MultiModalPlugin", diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 04d71826f29fa..e71ae5feec1c6 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,5 +1,5 @@ from vllm.inputs.registry import InputContext -from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin +from vllm.multimodal.base import MultiModalKwargs, MultiModalPlugin class AudioPlugin(MultiModalPlugin): @@ -9,7 +9,7 @@ def get_data_key(self) -> str: return "audio" def _default_input_mapper(self, ctx: InputContext, data: object, - **mm_processor_kwargs) -> MultiModalInputs: + **mm_processor_kwargs) -> MultiModalKwargs: raise NotImplementedError("There is no default audio input mapper") def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 6b10d0c609f13..a94ed1dc68013 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,186 +1,26 @@ -import sys from abc import ABC, abstractmethod -from collections import UserDict, defaultdict -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar, - Union, cast, final) - -import numpy as np -import torch -import torch.types -from PIL import Image +from collections import defaultdict +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, + Optional, Sequence, Tuple, Type, TypeVar, Union) + from torch import nn -from typing_extensions import TypeAlias from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, - json_map_leaves, resolve_mm_processor_kwargs) +from vllm.utils import (get_allowed_kwarg_only_overrides, + resolve_mm_processor_kwargs) if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.sequence import SequenceGroupMetadata -logger = init_logger(__name__) - -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] -""" -Uses a list instead of a tensor if the dimensions of each element do not match. -""" - -BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] -""" -A dictionary containing nested tensors which have been batched via -:meth:`MultiModalInputs.batch`. -""" - -if sys.version_info < (3, 9): - # UserDict cannot be subscripted - class _MultiModalInputsBase(UserDict): - pass -else: - - class _MultiModalInputsBase(UserDict[str, NestedTensors]): - pass - - -class MultiModalInputs(_MultiModalInputsBase): - """ - A dictionary that represents the keyword arguments to - :meth:`~torch.nn.Module.forward`. - """ - - @staticmethod - def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: - """ - Recursively stacks lists of tensors when they all have the same shape. - """ - if isinstance(nested_tensors, torch.Tensor): - return nested_tensors - - if isinstance(nested_tensors, np.ndarray): - return torch.from_numpy(nested_tensors) - - if isinstance(nested_tensors, (int, float)): - return torch.tensor(nested_tensors) - - stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] - if not is_list_of(stacked, torch.Tensor, check="all"): - # Only tensors (not lists) can be stacked. - return stacked - - tensors_ = cast(List[torch.Tensor], stacked) - if any(t.shape != tensors_[0].shape for t in tensors_): - # The tensors have incompatible shapes and can't be stacked. - return tensors_ - - return torch.stack(tensors_) - - @staticmethod - def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: - """ - Batch multiple inputs together into a dictionary. - - The resulting dictionary has the same keys as the inputs. - If the corresponding value from each input is a tensor and they all - share the same shape, the output value is a single batched tensor; - otherwise, the output value is a list containing the original value - from each input. - """ - if len(inputs_list) == 0: - return {} - - item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) - - for inputs in inputs_list: - # For models that supports multiple modalities (e.g. Qwen2-VL), - # different modalities will return different data keys, - # so batch() should skip the same key check. - - for k, v in inputs.items(): - item_lists[k].append(v) - - return { - k: MultiModalInputs._try_stack(item_list) - for k, item_list in item_lists.items() - } - - @staticmethod - def as_kwargs( - batched_inputs: BatchedTensorInputs, - *, - device: torch.types.Device, - ) -> BatchedTensorInputs: - json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) - - json_mapped = json_map_leaves( - lambda x: x.to(device, non_blocking=True), - json_inputs, - ) - - return cast(BatchedTensorInputs, json_mapped) +from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, + PlaceholderRange) - -_T = TypeVar("_T") - -MultiModalData: TypeAlias = Union[_T, List[_T]] -""" -Either a single data instance, or a list of data instances. - -The number of data instances allowed per modality is restricted by -`--limit-mm-per-prompt`. -""" - - -@final -class MultiModalDataBuiltins(TypedDict, total=False): - """Modality types that are predefined by vLLM.""" - - image: MultiModalData[Image.Image] - """The input image(s).""" - - audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]] - """The input audio item(s) and corresponding sampling rate(s).""" - - -MultiModalDataDict = Union[MultiModalDataBuiltins, - Mapping[str, MultiModalData[object]]] -""" -A dictionary containing an item for each modality type to input. - -Note: - This dictionary also accepts modality keys defined outside - :class:`MultiModalDataBuiltins` as long as a customized plugin is registered - through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. - Read more on that :ref:`here `. -""" - - -class PlaceholderRange(TypedDict): - """ - Placeholder location information for multi-modal data. - - For example: - Prompt: AAAA BBBB What is in these images? - Images A and B will have: - A: { "offset": 0, "length": 4 } - B: { "offset": 5, "length": 4 } - """ - - offset: int - """The start index of the placeholder in the prompt.""" - - length: int - """The length of the placeholder.""" - - -MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]] -""" -A dictionary containing placeholder ranges. -""" +logger = init_logger(__name__) MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], - MultiModalInputs] + MultiModalKwargs] """ Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers @@ -195,6 +35,7 @@ class PlaceholderRange(TypedDict): model. This does not include tokens that correspond to the input text. """ +_T = TypeVar("_T") N = TypeVar("N", bound=Type[nn.Module]) @@ -229,7 +70,7 @@ def _default_input_mapper( ctx: InputContext, data: MultiModalData[object], **mm_processor_kwargs, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: """ Return a dictionary to be passed as keyword arguments to :meth:`~torch.nn.Module.forward`. This is similar in concept to @@ -273,7 +114,7 @@ def wrapper(model_cls: N) -> N: def map_input(self, model_config: "ModelConfig", data: MultiModalData[object], - mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: + mm_processor_kwargs: Dict[str, Any]) -> MultiModalKwargs: """ Transform the data into a dictionary of model inputs using the input mapper registered for that model. @@ -500,7 +341,7 @@ def from_seq_group( def append_items_from_seq_group( self, positions: range, multi_modal_items: List[_T], - multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]: + multi_modal_placeholders: Sequence[PlaceholderRange]) -> List[_T]: """ Adds the multi-modal items that intersect ```positions`` to this placeholder map and returns the intersecting items. diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 3f6bb6c8338d2..589b46266b08d 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -10,7 +10,7 @@ from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of -from .base import MultiModalData, MultiModalInputs, MultiModalPlugin +from .base import MultiModalData, MultiModalKwargs, MultiModalPlugin if TYPE_CHECKING: from vllm.config import ModelConfig @@ -43,12 +43,12 @@ def _default_input_mapper( ctx: InputContext, data: MultiModalData[object], **mm_processor_kwargs, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: model_config = ctx.model_config # Processed by input processor if isinstance(data, BatchFeature): - return MultiModalInputs(data.data) + return MultiModalKwargs(data.data) # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): @@ -78,11 +78,11 @@ def _default_input_mapper( type(image_processor).__name__) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) # Image embedding elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): - return MultiModalInputs({"image_embeds": data}) + return MultiModalKwargs({"image_embeds": data}) raise TypeError(f"Invalid image type: {type(data)}") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py new file mode 100644 index 0000000000000..16f0d158556a6 --- /dev/null +++ b/vllm/multimodal/inputs.py @@ -0,0 +1,229 @@ +import sys +from collections import UserDict, defaultdict +from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, + TypedDict, TypeVar, Union, cast, final) + +import numpy as np +import torch +import torch.types +from PIL.Image import Image +from typing_extensions import TypeAlias + +from vllm.utils import JSONTree, is_list_of, json_map_leaves + +_T = TypeVar("_T") + +# yapf: disable +ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] +""" +A :class:`transformers.image_utils.ImageInput` representing a single image, +which can be passed to a HuggingFace :code:`ImageProcessor`. +""" + +VideoItem: TypeAlias = Union[ + List[Image], + np.ndarray, + torch.Tensor, + List[np.ndarray], + List[torch.Tensor], +] +""" + +A :class:`transformers.image_utils.VideoInput` representing a single video, +which can be passed to a HuggingFace :code:`VideoProcessor`. +""" + +AudioItem: TypeAlias = Union[ + np.ndarray, + List[float], + Tuple[np.ndarray, float], # DEPRECATED: Use mm_processor_kwargs instead +] +""" +Represents a single audio that can be inputted to a HuggingFace +:code:`AudioProcessor`. +""" +# yapf: enable + +MultiModalData: TypeAlias = Union[_T, List[_T]] +""" +Either a single data item, or a list of data items. + +The number of data items allowed per modality is restricted by +:code:`--limit-mm-per-prompt`. +""" + + +@final +class MultiModalDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: MultiModalData[ImageItem] + """The input image(s).""" + + video: MultiModalData[VideoItem] + """The input video(s).""" + + audio: MultiModalData[AudioItem] + """The input audio(s).""" + + +MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalDataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + + +class PlaceholderRange(TypedDict): + """ + Placeholder location information for multi-modal data. + + For example: + Prompt: AAAA BBBB What is in these images? + Images A and B will have: + A: { "offset": 0, "length": 4 } + B: { "offset": 5, "length": 4 } + """ + + offset: int + """The start index of the placeholder in the prompt.""" + + length: int + """The length of the placeholder.""" + + +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +""" +Uses a list instead of a tensor if the dimensions of each element do not match. +""" + +BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] +""" +A dictionary containing nested tensors which have been batched via +:meth:`MultiModalKwargs.batch`. +""" + +if sys.version_info < (3, 9): + # UserDict cannot be subscripted + class _MultiModalKwargsBase(UserDict): + pass +else: + + class _MultiModalKwargsBase(UserDict[str, NestedTensors]): + pass + + +class MultiModalKwargs(_MultiModalKwargsBase): + """ + A dictionary that represents the keyword arguments to + :meth:`~torch.nn.Module.forward`. + """ + + @staticmethod + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: + """ + Stack the inner dimensions that have the same shape in + a nested list of tensors. + + Thus, a dimension represented by a list means that the inner + dimensions are different for each element along that dimension. + """ + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors + + stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. + return stacked + + tensors_ = cast(List[torch.Tensor], stacked) + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ + + return torch.stack(tensors_) + + @staticmethod + def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: + """ + Batch multiple inputs together into a dictionary. + + The resulting dictionary has the same keys as the inputs. + If the corresponding value from each input is a tensor and they all + share the same shape, the output value is a single batched tensor; + otherwise, the output value is a list containing the original value + from each input. + """ + if len(inputs_list) == 0: + return {} + + # We need to consider the case where each item in the batch + # contains different modalities (i.e. different keys). + item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) + + for inputs in inputs_list: + for k, v in inputs.items(): + item_lists[k].append(v) + + return { + k: MultiModalKwargs._try_stack(item_list) + for k, item_list in item_lists.items() + } + + @staticmethod + def as_kwargs( + batched_inputs: BatchedTensorInputs, + *, + device: torch.types.Device, + ) -> BatchedTensorInputs: + json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) + + json_mapped = json_map_leaves( + lambda x: x.to(device, non_blocking=True), + json_inputs, + ) + + return cast(BatchedTensorInputs, json_mapped) + + +MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] +""" +A dictionary containing placeholder ranges. +""" + + +class MultiModalInputsV2(TypedDict): + """ + Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`, + ready to be passed to vLLM internals. + """ + + type: Literal["multimodal"] + """The type of inputs.""" + + prompt: str + """ + The original, unprocessed prompt text. + + Note: + Since prompt text is not required by vLLM internals, we leave this + unprocessed to save CPU computation. You can still call + :code:`tokenizer.decode(prompt_token_ids)` to get the processed text. + """ + + prompt_token_ids: List[int] + """The processed token IDs which includes placeholder tokens.""" + + mm_kwargs: MultiModalKwargs + """Keyword arguments to be directly passed to the model after batching.""" + + mm_placeholders: MultiModalPlaceholderDict + """ + For each modality, information about the placeholder tokens in + :code:`prompt_token_ids`. + """ diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py new file mode 100644 index 0000000000000..8d2e9987131ef --- /dev/null +++ b/vllm/multimodal/processing.py @@ -0,0 +1,275 @@ +from dataclasses import dataclass +from functools import lru_cache, partial +from typing import (Any, Callable, Collection, Generic, List, Mapping, + Optional, TypedDict, TypeVar, final) + +from transformers import BatchFeature +from typing_extensions import TypeAlias + +from vllm.inputs import InputProcessingContext +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import is_list_of + +from .inputs import (AudioItem, ImageItem, MultiModalDataDict, + MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, + VideoItem) + +_T = TypeVar("_T") + +ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]] +""" +Given the original data item, HF-processed data, and index of the processed +item, output the replacement token IDs to be allocated in vLLM. +""" + + +@dataclass +class ModalityProcessingMetadata(Generic[_T]): + placeholder_replacements: Mapping[str, ReplacementFunc] + """ + A dictionary where each item represents the original placeholder in the + prompt text and the corresponding replacement. + """ + + +class MultiModalProcessingMetadataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: ModalityProcessingMetadata[ImageItem] + video: ModalityProcessingMetadata[VideoItem] + audio: ModalityProcessingMetadata[AudioItem] + + +MultiModalProcessingMetadata: TypeAlias = \ + Mapping[str, ModalityProcessingMetadata[Any]] +""" +A dictionary containing an entry for each modality type to process. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + +MultiModalMultiData: TypeAlias = List[_T] +""" +A list of data items, where the number of data items allowed +per modality is restricted by :code:`--limit-mm-per-prompt`. +""" + + +@final +class MultiModalMultiDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: MultiModalMultiData[ImageItem] + """The input images.""" + + video: MultiModalMultiData[VideoItem] + """The input videos.""" + + audio: MultiModalMultiData[AudioItem] + """The input audios.""" + + +MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalMultiDataBuiltins` as long as a customized plugin + is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + + +def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: + """ + Convert a :class:`MultiModalDataDict` containing single data items + to a :class:`MultiModalMultiDataDict` containing multiple data items + per entry. + """ + multi_data: Mapping[str, MultiModalMultiData[Any]] = {} + + for k, v in data.items(): + # yapf: disable + if k == "image": + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + elif k == "video": + # Special case since even a single item can be a list + multi_data[k] = v if is_list_of(v, list) else [v] # type: ignore[index] + elif k == "audio": + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + else: + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + # yapf: enable + + return multi_data + + +def encode_no_special_tokens( + tokenizer: AnyTokenizer, + text: str, +) -> List[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, add_special_tokens=False)`. + """ + if isinstance(tokenizer, MistralTokenizer): + return tokenizer.tokenizer.encode(text, bos=False, eos=False) + + return tokenizer.encode(text, add_special_tokens=False) + + +@lru_cache +def candidate_placeholders( + tokenizer: AnyTokenizer, + placeholder_text: str, +) -> Collection[List[int]]: + """Generate token ID sequences that may represent a placeholder text.""" + # When the placeholder text is not mapped to a special token ID, + # it may be tokenized differently based on whether it is at the start/end + # of the string. So, we go through each combination of whether the text + # is at the start and end boundaries of the string + + # Matches the placeholder when it is in the middle of the string + start_id, = encode_no_special_tokens(tokenizer, "a") + end_id, = encode_no_special_tokens(tokenizer, "b") + + candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text) + + start_id_, *candidate_a = encode_no_special_tokens( + tokenizer, + f"a{placeholder_text}", + ) + assert start_id == start_id_ + + start_id_, *candidate_ab, end_id_ = encode_no_special_tokens( + tokenizer, + f"a{placeholder_text}b", + ) + assert start_id == start_id_ and end_id == end_id_ + + *candidate_b, end_id_ = encode_no_special_tokens( + tokenizer, + f"{placeholder_text}b", + ) + assert end_id == end_id_ + + # Remove duplicates (need to convert to tuple to be hashable) + unique_candidates = { + tuple(c) + for c in [candidate_basic, candidate_a, candidate_ab, candidate_b] + } + + # Convert back to list + return [list(c) for c in unique_candidates] + + +def apply_placeholders( + token_ids: List[int], + placeholder_ids: List[int], + get_replacement_ids: Callable[[], List[int]], +) -> Optional[PlaceholderRange]: + """ + Find the first occurrence of :code:`placeholder_ids`, + and replace it with the output of :code:`get_replacement_ids`. + + This function updates :code:`token_ids` in place. + """ + placeholder_length = len(placeholder_ids) + + for start_idx in range(len(token_ids) - placeholder_length + 1): + if token_ids[start_idx:placeholder_length] == placeholder_ids: + token_ids[start_idx:placeholder_length] = get_replacement_ids() + + return PlaceholderRange(offset=start_idx, + length=placeholder_length) + + return None + + +class MultiModalProcessor: + """ + Helper class to process multi-modal inputs to be used in vLLM. + """ + + def __init__( + self, + ctx: InputProcessingContext, + metadata: MultiModalProcessingMetadata, + ) -> None: + super().__init__() + + self.ctx = ctx + self.metadata = metadata + + def __call__( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + return self.apply(prompt, mm_data, mm_processor_kwargs) + + def apply( + self, + prompt: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + tokenizer = self.ctx.tokenizer + hf_processor = self.ctx.get_hf_processor() + + processed_inputs = hf_processor( + text=prompt, # type: ignore + **mm_data, + **mm_processor_kwargs, + ) + new_token_ids, = processed_inputs.pop("input_ids").tolist() + mm_kwargs = MultiModalKwargs(processed_inputs) + + mm_placeholders: Mapping[str, List[PlaceholderRange]] = {} + + for modality, orig_inputs in to_multi_format(mm_data).items(): + assert isinstance(orig_inputs, list) + + metadata = self.metadata[modality] + placeholder_replacements = metadata.placeholder_replacements + + modality_placeholders: List[PlaceholderRange] = [] + + for item_idx, orig_item in enumerate(orig_inputs): + for match_text, replace_fn in placeholder_replacements.items(): + candidates = candidate_placeholders(tokenizer, match_text) + get_replacement_ids = partial( + replace_fn, + orig_item, + processed_inputs, + item_idx, + ) + + for match_ids in candidates: + # TODO(youkaichao): Don't update new_token_ids + placeholders = apply_placeholders( + new_token_ids, + match_ids, + get_replacement_ids, + ) + + if placeholders is not None: + modality_placeholders.append(placeholders) + + # yapf: disable + mm_placeholders[modality] = modality_placeholders # type: ignore[index] + # yapf: enable + + return MultiModalInputsV2( + type="multimodal", + prompt=prompt, + prompt_token_ids=new_token_ids, + mm_kwargs=mm_kwargs, + mm_placeholders=mm_placeholders, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index bce2f4c6abe5b..0c30e60d2e91c 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,13 +1,20 @@ import functools from collections import UserDict -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, + Sequence, Type, TypeVar) +import torch.nn as nn +from typing_extensions import TypeAlias + +from vllm.inputs import InputProcessingContext from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer from .audio import AudioPlugin -from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, - MultiModalPlugin, MultiModalTokensCalc, NestedTensors) +from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin +from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors +from .processing import MultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -15,6 +22,16 @@ logger = init_logger(__name__) +N = TypeVar("N", bound=Type[nn.Module]) + +MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], + MultiModalProcessor] +""" +Constructs a :class:`MultiModalProcessor` instance from the context. + +The processing metadata should be derived from the context. +""" + class _MultiModalLimits(UserDict): """ @@ -45,6 +62,9 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} + self._processor_factories: Dict[Type[nn.Module], + MultiModalProcessorFactory] = {} + # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} @@ -103,7 +123,7 @@ def map_input( model_config: "ModelConfig", data: MultiModalDataDict, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: """ Apply an input mapper to the data passed to the model. @@ -139,7 +159,7 @@ def map_input( merged_dict[input_key] = input_tensor - return MultiModalInputs(merged_dict) + return MultiModalKwargs(merged_dict) def create_input_mapper(self, model_config: "ModelConfig"): """ @@ -243,3 +263,59 @@ def get_mm_limits_per_prompt( This should be called after :meth:`init_mm_limits_per_prompt`. """ return self._limits_by_model[model_config] + + def register_processor( + self, + factory: MultiModalProcessorFactory, + ): + """ + Register a multi-modal processor to a model class. + + When the model receives multi-modal data, the provided function is + invoked to transform the data into a dictionary of model inputs. + + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._processor_factories: + logger.warning( + "Model class %s already has an input mapper " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._processor_factories[model_cls] = factory + + return model_cls + + return wrapper + + def has_processor(self, model_config: "ModelConfig") -> bool: + """ + Test whether a multi-modal processor is defined for a specific model. + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + return model_cls in self._processor_factories + + def create_processor( + self, + model_config: "ModelConfig", + tokenizer: AnyTokenizer, + ) -> MultiModalProcessor: + """ + Create a multi-modal processor for a specific model and tokenizer. + """ + + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + processor_factory = self._processor_factories[model_cls] + + ctx = InputProcessingContext(model_config, tokenizer) + return processor_factory(ctx) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 6c2c6720f4276..27487d9d5946e 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -8,7 +8,7 @@ from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer -from .base import MultiModalData, MultiModalInputs +from .base import MultiModalData, MultiModalKwargs from .image import ImagePlugin if TYPE_CHECKING: @@ -54,7 +54,7 @@ def _default_input_mapper( ctx: InputContext, data: MultiModalData[object], **mm_processor_kwargs, - ) -> MultiModalInputs: + ) -> MultiModalKwargs: model_config = ctx.model_config if isinstance(data, list) and len(data) == 1: @@ -78,7 +78,7 @@ def _default_input_mapper( logger.error("Failed to process video (%s)", data) raise - return MultiModalInputs(batch_data) + return MultiModalKwargs(batch_data) raise TypeError(f"Invalid video type: {type(data)}") diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d7ddc7ec4447..2d35e525ec5cf 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -435,7 +435,7 @@ def n_blocks(self) -> int: def prompt(self) -> Optional[str]: inputs = self.inputs - if inputs["type"] == "token": + if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("prompt") assert_never(inputs) @@ -444,7 +444,7 @@ def prompt(self) -> Optional[str]: def prompt_token_ids(self) -> List[int]: inputs = self.inputs - if inputs["type"] == "token": + if inputs["type"] == "token" or inputs["type"] == "multimodal": return inputs.get("prompt_token_ids", []) assert_never(inputs) @@ -453,7 +453,7 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs - if inputs["type"] == "token": + if inputs["type"] == "token" or inputs["type"] == "multimodal": return None assert_never(inputs) @@ -465,23 +465,35 @@ def multi_modal_data(self) -> "MultiModalDataDict": if inputs["type"] == "token": return inputs.get("multi_modal_data", {}) + if inputs["type"] == "multimodal": + return inputs.get("mm_kwargs", {}) + assert_never(inputs) - @cached_property - def mm_processor_kwargs(self) -> Dict[str, Any]: + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: inputs = self.inputs if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) + return inputs.get("multi_modal_placeholders", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_placeholders", {}) assert_never(inputs) - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + @cached_property + def mm_processor_kwargs(self) -> Dict[str, Any]: inputs = self.inputs if inputs["type"] == "token": - return inputs.get("multi_modal_placeholders", {}) + return { + "needs_mm_mapper": True, + **inputs.get("mm_processor_kwargs", {}), + } + + if inputs["type"] == "multimodal": + return {} assert_never(inputs) @@ -953,6 +965,13 @@ def __post_init__(self): else: self.token_chunk_size = 1 + @property + def needs_mm_mapper(self): + # Interim measure so we can handle models that have yet to be + # updated to use the new multi-modal processor + return (self.mm_processor_kwargs is not None + and self.mm_processor_kwargs.get("needs_mm_mapper", False)) + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 17cc0ad1a4a3a..13d65d8304d19 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -18,7 +18,7 @@ "CUDA and ROCm flash attention backend.") from err from vllm.logger import init_logger -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -274,7 +274,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), **kwargs, ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 64cc18149d6c5..9b1bc841d905a 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -8,10 +8,11 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderLLMInputs, InputRegistry, PromptType) + EncoderDecoderInputs, InputRegistry, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -39,6 +40,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, ) -> None: @@ -128,8 +130,11 @@ def __init__( self.generation_config_fields = _load_generation_config_dict( model_config) - self.input_preprocessor = InputPreprocessor(model_config, - self.tokenizer) + self.input_preprocessor = InputPreprocessor( + model_config, + self.tokenizer, + mm_registry, + ) self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( model_config) @@ -213,7 +218,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs], + processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -436,7 +441,7 @@ def check_health(self) -> None: self.model_executor.check_health() def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderLLMInputs]): + EncoderDecoderInputs]): prompt_ids = inputs.get("prompt_token_ids") if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index 8ebbf6db939bc..994af7c5a455f 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -5,7 +5,7 @@ from vllm.attention import AttentionMetadata from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.cpu_model_runner import (CPUModelRunner, @@ -287,7 +287,7 @@ def execute_model( kv_caches, "attn_metadata": model_input.attn_metadata, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), "intermediate_tensors": intermediate_tensors, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index fdd72a452f2ad..c040f9fa191e2 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap) + MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.transformers_utils.config import uses_mrope @@ -159,7 +159,11 @@ def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, if not mm_data: return - mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) + if "needs_mm_mapper" in mm_processor_kwargs: + mm_kwargs = self.multi_modal_input_mapper(mm_data, + mm_processor_kwargs) + else: + mm_kwargs = mm_data # special processing for mrope position deltas. mrope_positions = None @@ -201,7 +205,7 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -226,7 +230,7 @@ def _prepare_prompt( ._compute_multi_modal_input( seq_group_metadata, seq_data, computed_len, seq_group_metadata.mm_processor_kwargs) - multi_modal_inputs_list.append(mm_kwargs) + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( placeholder_map) @@ -298,7 +302,7 @@ def _prepare_prompt( multi_modal_placeholder_index_maps=placeholder_index_maps, ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) @@ -527,7 +531,7 @@ def execute_model( kv_caches, "attn_metadata": model_input.attn_metadata, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), "intermediate_tensors": intermediate_tensors, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index ff288d5ca1512..37cfcbf13d7a3 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -8,7 +8,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, SequenceGroupMetadata) @@ -104,7 +104,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device)) if (self.observability_config is not None diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 90a43196084ea..008e0c9745994 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,7 +18,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.utils import get_architecture_class_name -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, @@ -206,7 +206,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 328dab598f8ef..384d94c46876f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -38,7 +38,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap, + MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping @@ -240,7 +240,7 @@ def __init__( prompt_adapter_request: Optional[PromptAdapterRequest] = None, # Multi-modal inputs. - multi_modal_inputs: Optional[MultiModalInputs] = None, + multi_modal_kwargs: Optional[MultiModalKwargs] = None, multi_modal_placeholder_maps: Optional[Dict[ str, MultiModalPlaceholderMap]] = None, @@ -361,7 +361,7 @@ def __init__( prompt_adapter_prompt_mapping or []) self.prompt_adapter_request = prompt_adapter_request - self.multi_modal_inputs = multi_modal_inputs + self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit @@ -646,10 +646,13 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if not mm_data: return - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) - inter_data.multi_modal_inputs = mm_kwargs + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, seq_group_metadata.mm_processor_kwargs) + else: + mm_kwargs = mm_data + + inter_data.multi_modal_kwargs = mm_kwargs inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. @@ -923,11 +926,11 @@ def build(self) -> ModelInputForGPU: ) # Multi-modal data. - multi_modal_inputs_list = [ - data.multi_modal_inputs for data in self.inter_data_list - if data.multi_modal_inputs is not None + multi_modal_kwargs_list = [ + data.multi_modal_kwargs for data in self.inter_data_list + if data.multi_modal_kwargs is not None ] - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return self.model_input_cls( input_tokens=input_tokens_tensor, @@ -1640,7 +1643,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), **seqlen_agnostic_kwargs) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 2da22cbfc7cb5..b9fc41bfe175a 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalKwargs) from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -122,7 +122,7 @@ def _prepare_prompt( input_block_ids: List[int] = [] seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -144,12 +144,16 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs, - ) - multi_modal_inputs_list.append(mm_kwargs) + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. + mm_processor_kwargs, + ) + else: + mm_kwargs = mm_data + + multi_modal_kwargs_list.append(mm_kwargs) max_seq_len = max(seq_lens) assert max_seq_len > 0 @@ -167,7 +171,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, input_block_ids, seq_lens, multi_modal_kwargs) @@ -314,7 +318,7 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), ) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index c9c87ea748081..2b69d579c17ce 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.openvino import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap) + MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import SequenceGroupMetadata from vllm.worker.model_runner_base import ModelRunnerBase @@ -102,7 +102,7 @@ def _prepare_model_input( seq_lens: List[int] = [] past_lens: List[int] = [] query_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -222,11 +222,16 @@ def _prepare_model_input( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata. - mm_processor_kwargs) - multi_modal_inputs_list.append(mm_kwargs) + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. + mm_processor_kwargs, + ) + else: + mm_kwargs = mm_data + + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( @@ -275,7 +280,7 @@ def _prepare_model_input( multi_modal_placeholder_index_maps=placeholder_index_maps, ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return ModelInput( input_tokens, @@ -341,7 +346,7 @@ def execute_model( kv_caches, "attn_metadata": attn_metadata, - **MultiModalInputs.as_kwargs(multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, device=self.device), } diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index bae8b469767b2..b34c4a413156d 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalPlaceholderMap, + MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata @@ -160,7 +160,7 @@ def _prepare_prompt( input_positions: List[int] = [] slot_mapping: List[int] = [] seq_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -191,8 +191,16 @@ def _prepare_prompt( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) + if seq_group_metadata.needs_mm_mapper: + mm_kwargs = self.runner.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. + mm_processor_kwargs, + ) + else: + mm_kwargs = mm_data + + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( @@ -264,7 +272,7 @@ def _prepare_prompt( block_tables=torch.tensor([], device=self.device, dtype=torch.int), ) - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) @@ -565,7 +573,7 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device)) # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: From d403af6abf3648816adb3d94fb131af131a5291f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 9 Nov 2024 04:00:50 +0000 Subject: [PATCH 02/16] Fix typo Signed-off-by: DarkLight1337 --- vllm/worker/hpu_model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 92d6552b2f428..1ff30d685c6b1 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -716,7 +716,7 @@ def _prepare_prompt( context_lens: List[int] = [] query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - multi_model_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_kwargs_list: List[MultiModalKwargs] = [] if len(seq_group_metadata_list) == 0: return PreparePromptMetadata.empty() @@ -777,7 +777,7 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data: mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_model_kwargs_list.append(mm_kwargs) + multi_modal_kwargs_list.append(mm_kwargs) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -876,7 +876,7 @@ def _prepare_prompt( multi_modal_placeholder_index_maps= None # FIXME(kzawora): mutli-modality will not work here ) - multi_modal_kwargs = MultiModalKwargs.batch(multi_model_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) return PreparePromptMetadata(input_tokens=input_tokens, input_positions=input_positions, From 9c23c3fd568de6a0aa8d40b0ad5f73583c3b5c67 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 9 Nov 2024 04:03:46 +0000 Subject: [PATCH 03/16] Add back handling of other data types Signed-off-by: DarkLight1337 --- vllm/multimodal/inputs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 16f0d158556a6..6727ca25af467 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -135,6 +135,12 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ if isinstance(nested_tensors, torch.Tensor): return nested_tensors + + # TODO: Remove these once all models have been migrated + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors] if not is_list_of(stacked, torch.Tensor, check="all"): From 6955998550bfe0d704c52ebda4bb8371b6ea571f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 9 Nov 2024 04:29:25 +0000 Subject: [PATCH 04/16] Proper detection of whether processor is used Signed-off-by: DarkLight1337 --- vllm/multimodal/inputs.py | 2 +- vllm/sequence.py | 12 +-------- vllm/worker/cpu_model_runner.py | 37 +++++++++++++++++----------- vllm/worker/model_runner.py | 10 +++++--- vllm/worker/neuron_model_runner.py | 12 ++++----- vllm/worker/openvino_model_runner.py | 12 ++++----- vllm/worker/xpu_model_runner.py | 10 ++++---- 7 files changed, 48 insertions(+), 47 deletions(-) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 6727ca25af467..3a005fd625708 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -135,7 +135,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ if isinstance(nested_tensors, torch.Tensor): return nested_tensors - + # TODO: Remove these once all models have been migrated if isinstance(nested_tensors, np.ndarray): return torch.from_numpy(nested_tensors) diff --git a/vllm/sequence.py b/vllm/sequence.py index 2d35e525ec5cf..bbf97576a2402 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -487,10 +487,7 @@ def mm_processor_kwargs(self) -> Dict[str, Any]: inputs = self.inputs if inputs["type"] == "token": - return { - "needs_mm_mapper": True, - **inputs.get("mm_processor_kwargs", {}), - } + return inputs.get("mm_processor_kwargs", {}) if inputs["type"] == "multimodal": return {} @@ -965,13 +962,6 @@ def __post_init__(self): else: self.token_chunk_size = 1 - @property - def needs_mm_mapper(self): - # Interim measure so we can handle models that have yet to be - # updated to use the new multi-modal processor - return (self.mm_processor_kwargs is not None - and self.mm_processor_kwargs.get("needs_mm_mapper", False)) - @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7d39d08015bbf..0d44e40209faa 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -146,23 +146,29 @@ def build(self) -> ModelInputForCPU: query_lens=seq_lens, ) - def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, - seq_data: SequenceData, computed_len: int, - mm_processor_kwargs: Dict[str, Any]): - + def _compute_multi_modal_input( + self, + seq_data: SequenceData, + computed_len: int, + seq_group_metadata: SequenceGroupMetadata, + ): # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group, range(computed_len, len(seq_data.get_token_ids()))) + seq_group_metadata, + range(computed_len, len(seq_data.get_token_ids())), + ) if not mm_data: - return + return None, None, None - if "needs_mm_mapper" in mm_processor_kwargs: - mm_kwargs = self.multi_modal_input_mapper(mm_data, - mm_processor_kwargs) - else: + if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) # special processing for mrope position deltas. mrope_positions = None @@ -225,10 +231,13 @@ def _prepare_prompt( mrope_positions = None if seq_group_metadata.multi_modal_data: - mm_kwargs, placeholder_maps, mrope_positions = self \ - ._compute_multi_modal_input( - seq_group_metadata, seq_data, computed_len, - seq_group_metadata.mm_processor_kwargs) + ( + mm_kwargs, + placeholder_maps, + mrope_positions, + ) = self._compute_multi_modal_input(seq_data, computed_len, + seq_group_metadata) + multi_modal_kwargs_list.append(mm_kwargs) for modality, placeholder_map in placeholder_maps.items(): multi_modal_placeholder_maps[modality].extend( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 77441794b5490..4b846c5d827dc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -658,11 +658,13 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, if not mm_data: return - if seq_group_metadata.needs_mm_mapper: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, seq_group_metadata.mm_processor_kwargs) - else: + if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data + else: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + seq_group_metadata.mm_processor_kwargs, + ) inter_data.multi_modal_kwargs = mm_kwargs inter_data.multi_modal_placeholder_maps = placeholder_maps diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index b9fc41bfe175a..ae4eb6ba6eaec 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -67,7 +67,8 @@ def __init__( self.pin_memory = is_pin_memory_available() # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ .create_input_mapper(self.model_config) # Lazy initialization. @@ -144,14 +145,13 @@ def _prepare_prompt( mm_data = seq_group_metadata.multi_modal_data if mm_data: - if seq_group_metadata.needs_mm_mapper: + if self.mm_registry.has_processor(self.model_config): + mm_kwargs = mm_data + else: mm_kwargs = self.multi_modal_input_mapper( mm_data, - mm_processor_kwargs=seq_group_metadata. - mm_processor_kwargs, + seq_group_metadata.mm_processor_kwargs, ) - else: - mm_kwargs = mm_data multi_modal_kwargs_list.append(mm_kwargs) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 2b69d579c17ce..6000e5dfe4e30 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -70,7 +70,8 @@ def __init__( ) # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ .create_input_mapper(self.model_config) # Lazy initialization. @@ -222,14 +223,13 @@ def _prepare_model_input( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - if seq_group_metadata.needs_mm_mapper: + if self.mm_registry.has_processor(self.model_config): + mm_kwargs = mm_data + else: mm_kwargs = self.multi_modal_input_mapper( mm_data, - mm_processor_kwargs=seq_group_metadata. - mm_processor_kwargs, + seq_group_metadata.mm_processor_kwargs, ) - else: - mm_kwargs = mm_data multi_modal_kwargs_list.append(mm_kwargs) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index b34c4a413156d..e6322e095bbb9 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -191,14 +191,14 @@ def _prepare_prompt( mm_data, placeholder_maps = MultiModalPlaceholderMap \ .from_seq_group(seq_group_metadata, positions_range) - if seq_group_metadata.needs_mm_mapper: + if self.runner.mm_registry.has_processor( + self.runner.model_config): + mm_kwargs = mm_data + else: mm_kwargs = self.runner.multi_modal_input_mapper( mm_data, - mm_processor_kwargs=seq_group_metadata. - mm_processor_kwargs, + seq_group_metadata.mm_processor_kwargs, ) - else: - mm_kwargs = mm_data multi_modal_kwargs_list.append(mm_kwargs) From 10818dd374319f428cda4b46456e9c607843d029 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 9 Nov 2024 04:36:24 +0000 Subject: [PATCH 05/16] Make this a cached property as well Signed-off-by: DarkLight1337 --- vllm/sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index bbf97576a2402..18d5675eec128 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -470,7 +470,7 @@ def multi_modal_data(self) -> "MultiModalDataDict": assert_never(inputs) - @property + @cached_property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: inputs = self.inputs From 264c5c6a46ccb2c0032ff879ef1669a6c4fcac9e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 9 Nov 2024 04:50:25 +0000 Subject: [PATCH 06/16] Cleanup Signed-off-by: DarkLight1337 --- vllm/multimodal/audio.py | 11 ++++++++--- vllm/multimodal/base.py | 1 + vllm/multimodal/image.py | 10 +++------- vllm/multimodal/video.py | 20 +++++--------------- 4 files changed, 17 insertions(+), 25 deletions(-) diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index 092a8aac15dd4..1a230602966d4 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,6 +1,7 @@ from vllm.inputs.registry import InputContext -from .base import MultiModalKwargs, MultiModalPlugin +from .base import MultiModalPlugin +from .inputs import AudioItem, MultiModalData, MultiModalKwargs class AudioPlugin(MultiModalPlugin): @@ -9,8 +10,12 @@ class AudioPlugin(MultiModalPlugin): def get_data_key(self) -> str: return "audio" - def _default_input_mapper(self, ctx: InputContext, data: object, - **mm_processor_kwargs) -> MultiModalKwargs: + def _default_input_mapper( + self, + ctx: InputContext, + data: MultiModalData[AudioItem], + **mm_processor_kwargs, + ) -> MultiModalKwargs: raise NotImplementedError("There is no default audio input mapper") def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 049df7f35efd7..6816c66dd1b79 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -133,6 +133,7 @@ def map_input( - :ref:`input_processing_pipeline` - :ref:`enabling_multimodal_inputs` """ + # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 589b46266b08d..97bbce1ce1570 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -3,14 +3,14 @@ import torch from PIL import Image -from transformers.image_processing_base import BatchFeature from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of -from .base import MultiModalData, MultiModalKwargs, MultiModalPlugin +from .base import MultiModalPlugin +from .inputs import ImageItem, MultiModalData, MultiModalKwargs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -41,15 +41,11 @@ def _get_hf_image_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[object], + data: MultiModalData[ImageItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config - # Processed by input processor - if isinstance(data, BatchFeature): - return MultiModalKwargs(data.data) - # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor( diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index a518270974f92..ba9bf58a4a20c 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional import numpy as np @@ -9,8 +9,9 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import is_list_of -from .base import MultiModalData, MultiModalKwargs +from .base import MultiModalData from .image import ImagePlugin +from .inputs import MultiModalKwargs, VideoItem if TYPE_CHECKING: from vllm.config import ModelConfig @@ -20,17 +21,6 @@ cached_get_video_processor = lru_cache(get_video_processor) cached_get_tokenizer = lru_cache(get_tokenizer) -VideoInput = Union[ - "np.ndarray", # single video input - List["np.ndarray"], - # TODO: support more types - # List[Image.Image], List[List[Image.Image]], - # "torch.Tensor", - # List["torch.Tensor"], - # List[List["np.ndarrray"]], - # List[List["torch.Tensor"]], -] - class VideoPlugin(ImagePlugin): """Plugin for video data.""" @@ -53,13 +43,13 @@ def _get_hf_video_processor( def _default_input_mapper( self, ctx: InputContext, - data: MultiModalData[object], + data: MultiModalData[VideoItem], **mm_processor_kwargs, ) -> MultiModalKwargs: model_config = ctx.model_config if isinstance(data, list) and len(data) == 1: - data = data[0] + data = data[0] # type: ignore if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): video_processor = self._get_hf_video_processor( From d0e2d45d80807a0b5e1a722734a6ea36f0a2f769 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 9 Nov 2024 13:09:43 +0000 Subject: [PATCH 07/16] Address comments Signed-off-by: DarkLight1337 --- vllm/inputs/preprocess.py | 15 ++++++++++++++- vllm/multimodal/processing.py | 6 ++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 15d2cdc64afa8..fdf28615fda10 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -203,9 +203,22 @@ async def _tokenize_prompt_async( lora_request=lora_request) def _can_process_multimodal(self) -> bool: + model_config = self.model_config + + if not model_config.is_multimodal_model: + raise ValueError("Your model does not support multi-modal inputs") + # Interim measure so we can handle models that have yet to be # updated to use the new multi-modal processor - return self.mm_registry.has_processor(self.model_config) + can_process_multimodal = self.mm_registry.has_processor(model_config) + if not can_process_multimodal: + logger.info( + "Your model uses the legacy input pipeline instead of the new " + "multi-modal processor. Please note that the legacy pipeline " + "will be removed in a future release. For more details, see: " + "https://github.com/vllm-project/vllm/issues/10114") + + return can_process_multimodal def _process_multimodal( self, diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 8d2e9987131ef..88a924da174a6 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -95,12 +95,10 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: for k, v in data.items(): # yapf: disable - if k == "image": - multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] - elif k == "video": + if k == "video": # Special case since even a single item can be a list multi_data[k] = v if is_list_of(v, list) else [v] # type: ignore[index] - elif k == "audio": + elif k in ("image", "audio"): multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] else: multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] From 00e29b41da23fd357047648dc70b0ce44cab66af Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 06:17:42 +0000 Subject: [PATCH 08/16] Fix imports Signed-off-by: DarkLight1337 --- vllm/model_executor/models/llava.py | 2 +- vllm/model_executor/models/phi3v.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index af712bf8f9506..005ae5e03cfed 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import NestedTensors +from vllm.multimodal.inputs import NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index de03d28638cda..4db65edc174f1 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -39,7 +39,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import NestedTensors, PlaceholderRange +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import is_list_of diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 81480786a09e1..eebd1de96537f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,7 +28,7 @@ from vllm.v1.sample.metadata import SamplingMetadata if TYPE_CHECKING: - from vllm.multimodal.base import PlaceholderRange + from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) From e264c949516e8e0334c4e51d369ce1e5a8f10a14 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 06:57:05 +0000 Subject: [PATCH 09/16] Fix more tests Signed-off-by: DarkLight1337 --- tests/v1/core/test_prefix_caching.py | 4 ++-- vllm/v1/request.py | 12 ++++-------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index e5a3b62258dd8..d614d3e67460f 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,5 +1,5 @@ """Compare the with and without prefix caching.""" -from vllm.inputs import DecoderOnlyInputs +from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import hash_block_tokens @@ -8,7 +8,7 @@ def make_request(request_id, prompt_token_ids): return Request( request_id=request_id, - inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids), + inputs=token_inputs(prompt_token_ids=prompt_token_ids), sampling_params=SamplingParams(max_tokens=17), eos_token_id=100, arrival_time=0, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index f35cf738c89bf..e9b326aad3526 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,7 +1,7 @@ import enum -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union -from vllm.inputs.data import DecoderOnlyInputs +from vllm.inputs import DecoderOnlyInputs, token_inputs from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams @@ -9,16 +9,13 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.utils import ConstantList -if TYPE_CHECKING: - from vllm.inputs import DecoderOnlyInputs - class Request: def __init__( self, request_id: str, - inputs: "DecoderOnlyInputs", + inputs: DecoderOnlyInputs, sampling_params: SamplingParams, eos_token_id: Optional[int], arrival_time: float, @@ -64,8 +61,7 @@ def __init__( def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, - inputs=DecoderOnlyInputs( - type="token", + inputs=token_inputs( prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, multi_modal_data=request.mm_data, From 769614d8b87adbc66f6d4ffd735831eb1405081d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 07:09:39 +0000 Subject: [PATCH 10/16] Fix types Signed-off-by: DarkLight1337 --- vllm/inputs/registry.py | 2 +- vllm/multimodal/inputs.py | 12 +----------- vllm/multimodal/registry.py | 2 +- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 47edcfd8ed1c8..e1e474b57c67d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -109,7 +109,7 @@ def __call__( ... -class _MultiModalCounts(UserDict): +class _MultiModalCounts(UserDict[str, int]): """ Wraps `mm_counts` for a more informative error message when attempting to access a plugin that does not exist. diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 3a005fd625708..64a4c58d5509c 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -1,4 +1,3 @@ -import sys from collections import UserDict, defaultdict from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, TypedDict, TypeVar, Union, cast, final) @@ -108,17 +107,8 @@ class PlaceholderRange(TypedDict): :meth:`MultiModalKwargs.batch`. """ -if sys.version_info < (3, 9): - # UserDict cannot be subscripted - class _MultiModalKwargsBase(UserDict): - pass -else: - class _MultiModalKwargsBase(UserDict[str, NestedTensors]): - pass - - -class MultiModalKwargs(_MultiModalKwargsBase): +class MultiModalKwargs(UserDict[str, NestedTensors]): """ A dictionary that represents the keyword arguments to :meth:`~torch.nn.Module.forward`. diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 0c30e60d2e91c..b992442d3b314 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -33,7 +33,7 @@ """ -class _MultiModalLimits(UserDict): +class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): """ Wraps `_limits_by_model` for a more informative error message when attempting to access a model that does not exist. From fdc5b6b0862f6424d27b3eaf8636ee77abc41584 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 07:12:47 +0000 Subject: [PATCH 11/16] Fix types 2 Signed-off-by: DarkLight1337 --- vllm/multimodal/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index c76df1116fb79..6eec660e42ac4 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -118,7 +118,7 @@ def map_input( self, model_config: "ModelConfig", data: MultiModalData[Any], - mm_processor_kwargs: Dict[str, Any], + mm_processor_kwargs: Optional[Dict[str, Any]], ) -> MultiModalKwargs: """ Transform the data into a dictionary of model inputs using the @@ -145,6 +145,9 @@ def map_input( raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + # In the case of the default mapper, we have to get resource # processor through its HuggingFace autoclass; since this goes # through **kwargs, we can't inspect it the same way, so we allow From 73c0c0339cdddb8351028c94ce11b140277f9da9 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 08:09:09 +0000 Subject: [PATCH 12/16] Factor out common code for prompt extraction Signed-off-by: DarkLight1337 --- vllm/engine/llm_engine.py | 11 +---- vllm/inputs/__init__.py | 8 ++-- vllm/inputs/data.py | 86 ++++++++++++++++++++++++++++++++++--- vllm/inputs/registry.py | 37 ++++++++++++++-- vllm/sequence.py | 77 ++++++++------------------------- vllm/v1/engine/async_llm.py | 4 ++ vllm/v1/engine/processor.py | 68 +++++++++++++++++++++-------- 7 files changed, 193 insertions(+), 98 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5e9967cd89995..f5299746d845d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -30,7 +30,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, - PromptType) + PromptType, SingletonInputsAdapter) from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger @@ -854,13 +854,6 @@ def add_request( ) processed_inputs = self.input_processor(preprocessed_inputs) - # This is a bit of a hack - copy the mm_processor_kwargs that were - # used in the input processor to the processed output, since these - # kwargs are presumed to be immutable and the values should be aligned - # between the input processor (here) and the input mapper. - processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( - "mm_processor_kwargs") - self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, @@ -2022,7 +2015,7 @@ def _validate_model_inputs(self, inputs: ProcessorInputs, else: prompt_inputs = inputs - prompt_ids = prompt_inputs.get("prompt_token_ids") + prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 338589ed04dc4..54fbd7a321a6f 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,8 +1,9 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, build_explicit_enc_dec_prompt, - to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) + SingletonInputs, SingletonInputsAdapter, SingletonPrompt, + TextPrompt, TokenInputs, TokensPrompt, + build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, + token_inputs, zip_enc_dec_prompts) from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -27,6 +28,7 @@ "EncoderDecoderInputs", "ProcessorInputs", "SingletonInputs", + "SingletonInputsAdapter", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index ac1a425538c34..412580d1ea88d 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,7 +1,10 @@ +from dataclasses import dataclass +from functools import cached_property from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, Tuple, Union, cast) -from typing_extensions import NotRequired, TypedDict, TypeVar +import torch +from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict @@ -205,6 +208,78 @@ class EncoderDecoderInputs(TypedDict): :class:`vllm.sequence.Sequence`. """ + +@dataclass +class SingletonInputsAdapter: + """ + Unified interface to access the components of :class:`SingletonInputs`. + """ + inputs: SingletonInputs + + @cached_property + def prompt(self) -> Optional[str]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("prompt") + + assert_never(inputs) + + @cached_property + def prompt_token_ids(self) -> List[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("prompt_token_ids", []) + + assert_never(inputs) + + @cached_property + def prompt_embeds(self) -> Optional[torch.Tensor]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return None + + assert_never(inputs) + + @cached_property + def multi_modal_data(self) -> "MultiModalDataDict": + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_data", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_kwargs", {}) + + assert_never(inputs) + + @cached_property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_placeholders", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_placeholders", {}) + + assert_never(inputs) + + @cached_property + def mm_processor_kwargs(self) -> Dict[str, Any]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("mm_processor_kwargs", {}) + + if inputs["type"] == "multimodal": + return {} + + assert_never(inputs) + + ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] """ The inputs to :data:`vllm.inputs.InputProcessor`. @@ -235,10 +310,11 @@ def zip_enc_dec_prompts( ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of - :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs - may also be provided; if a dict is passed, the same dictionary will be - used for every encoder/decoder prompt. If an iterable is provided, it will - be zipped with the encoder/decoder prompts. + :class:`ExplicitEncoderDecoderPrompt` instances. + + ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same + dictionary will be used for every encoder/decoder prompt. If an iterable is + provided, it will be zipped with the encoder/decoder prompts. """ if mm_processor_kwargs is None: mm_processor_kwargs = cast(Dict[str, Any], {}) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index e1e474b57c67d..68b4756331e6d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -6,7 +6,7 @@ from torch import nn from transformers import PretrainedConfig, ProcessorMixin -from typing_extensions import TypeVar +from typing_extensions import TypeVar, assert_never from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_get_processor @@ -14,7 +14,8 @@ from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) -from .data import ProcessorInputs +from .data import ProcessorInputs, SingletonInputs +from .parse import is_encoder_decoder_inputs if TYPE_CHECKING: from vllm.config import ModelConfig @@ -302,6 +303,21 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]): return self._input_processors_by_model_type \ .get(model_cls, self._default_input_processor) + def _ensure_mm_kwargs( + self, + inputs: SingletonInputs, + mm_processor_kwargs: Dict[str, Any], + ): + if inputs["type"] == "token": + # In case the input processor for that model fails to set it + if "mm_processor_kwargs" not in inputs: + inputs["mm_processor_kwargs"] = mm_processor_kwargs + elif inputs["type"] == "multimodal": + # Be more strict in V2 + assert "mm_kwargs" in inputs + else: + assert_never(inputs["type"]) + def process_input(self, model_config: "ModelConfig", inputs: ProcessorInputs) -> ProcessorInputs: """ @@ -327,8 +343,21 @@ def process_input(self, model_config: "ModelConfig", processor, ) - return processor(InputContext(model_config), inputs, - **mm_processor_kwargs) + processed_inputs = processor( + InputContext(model_config), + inputs, + **mm_processor_kwargs, + ) + + if is_encoder_decoder_inputs(processed_inputs): + self._ensure_mm_kwargs(processed_inputs["encoder"], + mm_processor_kwargs) + self._ensure_mm_kwargs(processed_inputs["decoder"], + mm_processor_kwargs) + else: + self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs) + + return processed_inputs def create_input_processor(self, model_config: "ModelConfig"): """ diff --git a/vllm/sequence.py b/vllm/sequence.py index 46d18d6e4d862..3b41d25a2fe42 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,25 +5,21 @@ from array import array from collections import defaultdict from dataclasses import dataclass, field -from functools import cached_property, reduce -from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, - Mapping, Optional) +from functools import reduce +from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union import msgspec import torch -from typing_extensions import assert_never +from vllm.inputs import SingletonInputs, SingletonInputsAdapter from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams -if TYPE_CHECKING: - from vllm.inputs import SingletonInputs - VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_INVALID_TOKEN_ID = -1 @@ -407,14 +403,14 @@ class Sequence: def __init__( self, seq_id: int, - inputs: "SingletonInputs", + inputs: SingletonInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.seq_id = seq_id - self.inputs = inputs + self.inputs = SingletonInputsAdapter(inputs) self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request @@ -441,68 +437,29 @@ def __init__( def n_blocks(self) -> int: return (self.get_len() + self.block_size - 1) // self.block_size - @cached_property + @property def prompt(self) -> Optional[str]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt") + return self.inputs.prompt - assert_never(inputs) - - @cached_property + @property def prompt_token_ids(self) -> List[int]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return inputs.get("prompt_token_ids", []) + return self.inputs.prompt_token_ids - assert_never(inputs) - - @cached_property + @property def prompt_embeds(self) -> Optional[torch.Tensor]: - inputs = self.inputs - - if inputs["type"] == "token" or inputs["type"] == "multimodal": - return None - - assert_never(inputs) + return self.inputs.prompt_embeds - @cached_property + @property def multi_modal_data(self) -> "MultiModalDataDict": - inputs = self.inputs + return self.inputs.multi_modal_data - if inputs["type"] == "token": - return inputs.get("multi_modal_data", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_kwargs", {}) - - assert_never(inputs) - - @cached_property + @property def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("multi_modal_placeholders", {}) - - if inputs["type"] == "multimodal": - return inputs.get("mm_placeholders", {}) + return self.inputs.multi_modal_placeholders - assert_never(inputs) - - @cached_property + @property def mm_processor_kwargs(self) -> Dict[str, Any]: - inputs = self.inputs - - if inputs["type"] == "token": - return inputs.get("mm_processor_kwargs", {}) - - if inputs["type"] == "multimodal": - return {} - - assert_never(inputs) + return self.inputs.mm_processor_kwargs @property def lora_int_id(self) -> int: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 2d7c58cfea13b..09bff9655a882 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -6,6 +6,7 @@ from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.protocol import EngineClient from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -321,6 +322,9 @@ async def get_model_config(self) -> ModelConfig: async def get_decoding_config(self): raise ValueError("Not Supported on V1 yet.") + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.processor.input_preprocessor + async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index a0c434c0e886e..5c1577190c75a 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -2,8 +2,9 @@ from typing import Any, Dict, Mapping, Optional, Tuple, Union from vllm.config import LoRAConfig, ModelConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputRegistry, PromptType) +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType, SingletonInputsAdapter) +from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry @@ -11,7 +12,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.config import try_get_generation_config -from vllm.transformers_utils.tokenizer_group import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest @@ -21,7 +22,7 @@ def __init__( self, model_config: ModelConfig, lora_config: Optional[LoRAConfig], - tokenizer: AnyTokenizer, + tokenizer: BaseTokenizerGroup, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): @@ -76,6 +77,19 @@ def process_inputs( self._validate_model_inputs(processed_inputs) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + if is_encoder_decoder_inputs(processed_inputs): + decoder_inputs = SingletonInputsAdapter( + processed_inputs["decoder"]) + encoder_inputs = SingletonInputsAdapter( + processed_inputs["encoder"]) + else: + decoder_inputs = SingletonInputsAdapter(processed_inputs) + encoder_inputs = None + + # TODO: Impl encoder-decoder + if encoder_inputs is not None: + raise NotImplementedError + assert isinstance(params, SamplingParams) # TODO: can we avoid cloning here in multiproc case sampling_params = params.clone() @@ -84,27 +98,43 @@ def process_inputs( # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest( - request_id, processed_inputs.get("prompt"), - processed_inputs.get("prompt_token_ids"), + request_id, + decoder_inputs.prompt, + decoder_inputs.prompt_token_ids, sampling_params.skip_special_tokens, sampling_params.spaces_between_special_tokens, - sampling_params.output_kind, sampling_params.stop, - sampling_params.include_stop_str_in_output) + sampling_params.output_kind, + sampling_params.stop, + sampling_params.include_stop_str_in_output, + ) # Make Request for EngineCore. engine_core_request = EngineCoreRequest( - request_id, processed_inputs.get("prompt"), - processed_inputs.get("prompt_token_ids"), - processed_inputs.get("multi_modal_data"), - processed_inputs.get("multi_modal_placeholders"), - processed_inputs.get("mm_processor_kwargs"), sampling_params, - eos_token_id, arrival_time, lora_request) + request_id, + decoder_inputs.prompt, + decoder_inputs.prompt_token_ids, + decoder_inputs.multi_modal_data, + decoder_inputs.multi_modal_placeholders, + decoder_inputs.mm_processor_kwargs, + sampling_params, + eos_token_id, + arrival_time, + lora_request, + ) return detokenizer_request, engine_core_request - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs]): - prompt_ids = inputs.get("prompt_token_ids") + def _validate_model_inputs(self, inputs: ProcessorInputs): + if is_encoder_decoder_inputs(inputs): + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + prompt_inputs = inputs["decoder" if self.model_config. + is_multimodal_model else "encoder"] + else: + prompt_inputs = inputs + + prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids + if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") @@ -120,6 +150,10 @@ def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, "inputs, the number of image tokens depends on the number " "of images, and possibly their aspect ratios as well.") + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: config = try_get_generation_config( From fb9c54bbbcd249e8fc1f012900bab97be9a061ab Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 08:14:43 +0000 Subject: [PATCH 13/16] Fix import Signed-off-by: DarkLight1337 --- vllm/inputs/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 412580d1ea88d..07ff9faa50f13 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -256,7 +256,7 @@ def multi_modal_data(self) -> "MultiModalDataDict": assert_never(inputs) @cached_property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": inputs = self.inputs if inputs["type"] == "token": From 10e881f0679394b08363b43342edf3c4517bef66 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 09:37:56 +0000 Subject: [PATCH 14/16] Fix tests Signed-off-by: DarkLight1337 --- tests/multimodal/test_processor_kwargs.py | 37 ++++++++++++++--------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 4d3bbd805c152..e6c8793989e13 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -1,12 +1,12 @@ from array import array -from typing import Mapping +from typing import Callable, Dict, Mapping, Optional from unittest.mock import patch import pytest import torch from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext, - InputRegistry, token_inputs) + InputRegistry, ProcessorInputs, token_inputs) from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -34,10 +34,9 @@ def custom_processor(ctx: InputContext, inputs: DecoderOnlyInputs, *, num_crops=DEFAULT_NUM_CROPS): - # For testing purposes, we don't worry about the llm inputs / return - # type validation, and just return the value of the kwarg that we - # clobber. - return num_crops + # For testing purposes, we don't worry about the prompt + return token_inputs(prompt_token_ids=[], + mm_processor_kwargs={"num_crops": num_crops}) with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", return_value=custom_processor): @@ -109,6 +108,21 @@ def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): return init_kwargs, inference_kwargs, expected_seq_count +def _get_processed_num_crops( + processor: Callable[[ProcessorInputs], ProcessorInputs], + inference_kwargs: Optional[Dict[str, int]], +) -> int: + processed_inputs = processor( + token_inputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=inference_kwargs)) + + assert "type" in processed_inputs + assert processed_inputs["type"] == "token" + assert "mm_processor_kwargs" in processed_inputs + return processed_inputs["mm_processor_kwargs"]["num_crops"] + + @pytest.mark.parametrize("init_num_crops,inference_num_crops", [ (None, None), (NUM_CROPS_OVERRIDE, None), @@ -124,10 +138,8 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops, ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) - num_crops_val = processor( - token_inputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=inference_kwargs)) + num_crops_val = _get_processed_num_crops(processor, inference_kwargs) + assert num_crops_val == expected_seq_count @@ -153,10 +165,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor = dummy_registry.create_input_processor(ctx.model_config) # Should filter out the inference time kwargs - num_crops_val = processor( - token_inputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=mm_processor_kwargs)) + num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs) assert num_crops_val == DEFAULT_NUM_CROPS From fd66e86b2f6d136ca4d3409364b893d0a9886deb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 10:38:39 +0000 Subject: [PATCH 15/16] Fix request Signed-off-by: DarkLight1337 --- vllm/v1/request.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index e9b326aad3526..b09d935067d70 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,7 +1,7 @@ import enum from typing import List, Optional, Union -from vllm.inputs import DecoderOnlyInputs, token_inputs +from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams @@ -22,7 +22,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id - self.inputs = inputs + self.inputs = SingletonInputsAdapter(inputs) self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id @@ -38,17 +38,17 @@ def __init__( assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - self.prompt = inputs.get("prompt") - self.prompt_token_ids = inputs["prompt_token_ids"] + self.prompt = self.inputs.prompt + self.prompt_token_ids = self.inputs.prompt_token_ids self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: List[int] = [] self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 # Raw multimodal data before the mm input mapper (e.g., PIL images). - self.mm_data = inputs.get("multi_modal_data") - self.mm_processor_kwargs = inputs.get("mm_processor_kwargs") - mm_positions = inputs.get("multi_modal_placeholders") + self.mm_data = self.inputs.multi_modal_data + self.mm_processor_kwargs = self.inputs.mm_processor_kwargs + mm_positions = self.inputs.multi_modal_placeholders if mm_positions: # FIXME(woosuk): Support other modalities. self.mm_positions = mm_positions.get("image", []) From 25d19801a7fdabee66d9ca0197c25dc6f0f20f71 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 11:08:18 +0000 Subject: [PATCH 16/16] Fix request Signed-off-by: DarkLight1337 --- vllm/v1/request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index b09d935067d70..51fb4003e5fe0 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -110,7 +110,7 @@ def get_finished_reason(self) -> Union[str, None]: return RequestStatus.get_finished_reason(self.status) def has_encoder_inputs(self) -> bool: - return self.mm_data is not None + return len(self.mm_data) > 0 @property def num_encoder_inputs(self) -> int: