From 101418096ffe3c83b6d541e1303b10e9d5e03861 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 28 Dec 2024 01:22:48 +0800 Subject: [PATCH] [VLM] Support caching in merged multi-modal processor (#11396) Signed-off-by: DarkLight1337 --- docs/source/conf.py | 3 +- .../design/multimodal/multimodal_index.md | 24 +- docs/source/models/supported_models.md | 3 +- .../openai/test_vision_embedding.py | 4 +- .../mm_processor_kwargs/test_qwen2_vl.py | 2 +- .../vision_language/test_models.py | 4 +- tests/multimodal/test_processing.py | 209 ++++++- vllm/inputs/registry.py | 22 +- vllm/model_executor/models/llava.py | 178 +++--- vllm/model_executor/models/phi3v.py | 107 +++- vllm/model_executor/models/qwen.py | 4 +- vllm/model_executor/models/qwen2_audio.py | 65 ++- vllm/model_executor/models/qwen2_vl.py | 115 ++-- vllm/model_executor/models/ultravox.py | 76 ++- vllm/multimodal/base.py | 44 +- vllm/multimodal/inputs.py | 438 ++++++++++++++- vllm/multimodal/processing.py | 516 ++++++++++++------ vllm/multimodal/registry.py | 50 +- vllm/transformers_utils/processor.py | 12 +- vllm/utils.py | 27 +- 20 files changed, 1455 insertions(+), 448 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1fe0474631140..71394c5302a39 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -191,6 +191,7 @@ def linkcode_resolve(domain, info): # Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ + "blake3", "compressed_tensors", "cpuinfo", "cv2", @@ -207,7 +208,7 @@ def linkcode_resolve(domain, info): "tensorizer", "pynvml", "outlines", - "xgrammar," + "xgrammar", "librosa", "soundfile", "gguf", diff --git a/docs/source/design/multimodal/multimodal_index.md b/docs/source/design/multimodal/multimodal_index.md index 88af07afc7018..e4f2171e84ff7 100644 --- a/docs/source/design/multimodal/multimodal_index.md +++ b/docs/source/design/multimodal/multimodal_index.md @@ -45,39 +45,39 @@ adding_multimodal_plugin ### Base Classes ```{eval-rst} -.. autodata:: vllm.multimodal.NestedTensors +.. automodule:: vllm.multimodal.base + :members: + :show-inheritance: ``` -```{eval-rst} -.. autodata:: vllm.multimodal.BatchedTensorInputs -``` +### Input Classes ```{eval-rst} -.. autoclass:: vllm.multimodal.MultiModalDataBuiltins +.. automodule:: vllm.multimodal.inputs :members: :show-inheritance: ``` -```{eval-rst} -.. autodata:: vllm.multimodal.MultiModalDataDict -``` +### Audio Classes ```{eval-rst} -.. autoclass:: vllm.multimodal.MultiModalKwargs +.. automodule:: vllm.multimodal.audio :members: :show-inheritance: ``` +### Image Classes + ```{eval-rst} -.. autoclass:: vllm.multimodal.MultiModalPlugin +.. automodule:: vllm.multimodal.image :members: :show-inheritance: ``` -### Image Classes +### Video Classes ```{eval-rst} -.. automodule:: vllm.multimodal.image +.. automodule:: vllm.multimodal.video :members: :show-inheritance: ``` diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 95add0d71bbab..7acafda50793c 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -755,8 +755,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal ``` ```{note} -To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo ({code}`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`) -and pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. 
+To use {code}`TIGER-Lab/Mantis-8B-siglip-llama3`, you have pass {code}`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. ``` ```{note} diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py index 3731b2dcdeae1..c851539c610ec 100644 --- a/tests/entrypoints/openai/test_vision_embedding.py +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -91,5 +91,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, assert len(embeddings.data) == 1 assert len(embeddings.data[0].embedding) == 3072 assert embeddings.usage.completion_tokens == 0 - assert embeddings.usage.prompt_tokens == 765 - assert embeddings.usage.total_tokens == 765 + assert embeddings.usage.prompt_tokens == 764 + assert embeddings.usage.total_tokens == 764 diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py index cd8954ffc48c2..5897c04c89e19 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py @@ -30,7 +30,7 @@ def get_max_qwen2_vl_image_tokens(): @pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [ - ({}, 1225), + ({}, 16384), ({ MIN_PIXELS: 64**2, MAX_PIXELS: 512**2 diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3101d1d2ea831..1a9c1b4ef1be0 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -201,6 +201,7 @@ vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + marks=[large_gpu_mark(min_gb=48)], ), "glm4": VLMTestInfo( models=["THUDM/glm-4v-9b"], @@ -212,7 +213,7 @@ dtype="bfloat16", get_stop_token_ids=lambda tok: [151329, 151336, 151338], patch_hf_runner=model_utils.glm_patch_hf_runner, - marks=[large_gpu_mark(min_gb=48)], + marks=[large_gpu_mark(min_gb=32)], ), "h2ovl": VLMTestInfo( models = [ @@ -261,6 +262,7 @@ dtype="bfloat16", use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], ), "llava_next": VLMTestInfo( models=["llava-hf/llava-v1.6-mistral-7b-hf"], diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index d22d778f81fa8..1b2847ed0f534 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -1,12 +1,20 @@ +from functools import partial from typing import cast +import numpy as np import pytest - -from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, - find_text_matches, find_token_matches, - iter_placeholders, iter_token_matches, +from PIL import Image + +from vllm.config import ModelConfig +from vllm.inputs import InputProcessingContext +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, + _PlaceholderInfo, find_text_matches, + find_token_matches, iter_placeholders, + iter_token_matches, replace_text_matches, replace_token_matches) +from vllm.multimodal.utils import cached_get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby @@ -457,6 +465,7 @@ def test_find_replace_tokens( ), ] ) +# yapf: 
enable def test_iter_placeholders( repl_by_key, prompt, @@ -475,11 +484,199 @@ def test_iter_placeholders( prompt_repls, prompt, # Effectively match all occurrences in the prompt - {key: 3 for key in repl_by_key}, - )) + {key: 3 + for key in repl_by_key}, + )) # Only displayed on error print("result:", result) # Manually constructed results assert result == expected + + +def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int): + w, h = rng.randint(min_wh, max_wh, size=(2, )) + arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8) + return Image.fromarray(arr) + + +def _rand_video( + rng: np.random.RandomState, + min_frames: int, + max_frames: int, + min_wh: int, + max_wh: int, +): + # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 + num_frames = rng.randint(min_frames, max_frames) + num_frames = (num_frames // 2) * 2 + + w, h = rng.randint(min_wh, max_wh, size=(2, )) + return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) + + +def _rand_audio( + rng: np.random.RandomState, + min_len: int, + max_len: int, + sr: int, +): + audio_len = rng.randint(min_len, max_len) + return rng.rand(audio_len), sr + + +def _test_processing_cache_correctness( + model_id: str, + modalities: set[str], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3": + hf_overrides = {"architectures": ["MantisForConditionalGeneration"]} + else: + hf_overrides = {} + + model_config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=True, + seed=0, + dtype="float16", + revision=None, + hf_overrides=hf_overrides, + ) + model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) + + processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls] + ctx = InputProcessingContext( + model_config, + tokenizer=cached_get_tokenizer(model_config.tokenizer), + ) + # Ensure that it can fit all of the data + cache = ProcessingCache(capacity=1 << 30) + + baseline_processor = processor_factory(ctx, cache=None) + cached_processor = processor_factory(ctx, cache=cache) + + rng = np.random.RandomState(0) + + input_to_hit = { + "image": Image.new("RGB", size=(128, 128)), + "video": np.zeros((4, 128, 128, 3), dtype=np.uint8), + "audio": (np.zeros((512, )), 16000), + } + input_factory = { + "image": + partial(_rand_img, rng, min_wh=128, max_wh=256), + "video": + partial(_rand_video, + rng, + min_frames=2, + max_frames=8, + min_wh=128, + max_wh=256), + "audio": + partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000), + } + input_max_count = { + "image": 3, + "video": 3, + "audio": 3, + } + + for batch_idx in range(num_batches): + mm_data = { + k: + [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) + for _ in range(rng.randint(input_max_count[k]))] + for k in modalities + } + + mm_counts = {k: len(vs) for k, vs in mm_data.items()} + prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text + + # Drop unnecessary keys and test single -> multi conversion + if rng.rand() < simplify_rate: + for k in list(mm_data.keys()): + if not mm_data[k]: + del mm_data[k] + elif len(mm_data[k]) == 1: + mm_data[k] = mm_data[k][0] + + baseline_result = baseline_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + cached_result = cached_processor.apply( + prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert baseline_result == cached_result, ( + f"Failed ({batch_idx=}, {mm_data=})") + + +# 
yapf: disable +@pytest.mark.parametrize(("model_id", "modalities"), [ + ("llava-hf/llava-1.5-7b-hf", {"image"}), + ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image"}), + ("mistral-community/pixtral-12b", {"image"}), + ("Qwen/Qwen2-VL-2B-Instruct", {"image", "video"}), + ("Qwen/Qwen2-Audio-7B-Instruct", {"audio"}), + ("fixie-ai/ultravox-v0_3", {"audio"}), +]) +@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("num_batches", [32]) +@pytest.mark.parametrize("simplify_rate", [1.0]) +# yapf: enable +def test_processing_cache_correctness( + model_id: str, + modalities: set[str], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + _test_processing_cache_correctness( + model_id, + modalities, + hit_rate=hit_rate, + num_batches=num_batches, + simplify_rate=simplify_rate, + ) + + +# yapf: disable +@pytest.mark.parametrize(("model_id", "modalities"), [ + ("microsoft/Phi-3-vision-128k-instruct", {"image"}), +]) +@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) +@pytest.mark.parametrize("num_batches", [32]) +@pytest.mark.parametrize("simplify_rate", [1.0]) +# yapf: enable +def test_processing_cache_correctness_phi3v( + model_id: str, + modalities: set[str], + hit_rate: float, + num_batches: int, + simplify_rate: float, +): + # HACK - this is an attempted workaround for the following bug + # https://github.com/huggingface/transformers/issues/34307 + from transformers import AutoImageProcessor # noqa: F401 + from transformers import AutoProcessor # noqa: F401 + + AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) + + _test_processing_cache_correctness( + model_id, + modalities, + hit_rate=hit_rate, + num_batches=num_batches, + simplify_rate=simplify_rate, + ) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f3ec9d115c9ba..46346b08e99c2 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -99,6 +99,9 @@ def get_hf_processor( merged_kwargs = {**base_kwargs, **kwargs} + if isinstance(typ, type): + merged_kwargs["processor_cls"] = typ + hf_processor = cached_get_processor( self.model_config.model, trust_remote_code=self.model_config.trust_remote_code, @@ -132,10 +135,13 @@ def get_hf_processor( def call_hf_processor( self, hf_processor: ProcessorMixin, - prompt: str, - processor_data: Mapping[str, object], - inference_kwargs: Mapping[str, object], + data: Mapping[str, object], + kwargs: Mapping[str, object] = {}, ) -> BatchFeature: + """ + Call :code:`hf_processor` on the prompt :code:`data` + (text, image, audio...) with configurable options :code:`kwargs`. 
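+
+        For example (an illustrative sketch; ``hf_processor``, the prompt,
+        the image and the ``num_crops`` option below are placeholders, not
+        part of this API):
+
+        .. code-block:: python
+
+            ctx.call_hf_processor(
+                hf_processor,
+                data=dict(text="<image> What is this?", images=[image]),
+                kwargs=dict(num_crops=4),
+            )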
+ """ assert callable(hf_processor) base_kwargs = self.model_config.mm_processor_kwargs @@ -144,21 +150,15 @@ def call_hf_processor( merged_kwargs = resolve_mm_processor_kwargs( base_kwargs, - inference_kwargs, + kwargs, hf_processor, requires_kw_only=False, allow_var_kwargs=True, ) try: - return hf_processor( - text=prompt, - **processor_data, - **merged_kwargs, - return_tensors="pt", - ) + return hf_processor(**data, **merged_kwargs, return_tensors="pt") except Exception as exc: - data = dict(text=prompt, **processor_data) msg = (f"Failed to apply {type(hf_processor).__name__} " f"on data={data} with kwargs={merged_kwargs}") diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 0662d90e79b92..0ecba5a1cae0f 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,5 +1,4 @@ from functools import cached_property -from types import MethodType from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) @@ -7,7 +6,7 @@ import torch.nn as nn from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, PretrainedConfig, - ProcessorMixin, SiglipVisionConfig) + SiglipVisionConfig) from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor @@ -21,10 +20,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, + MultiModalFieldConfig, MultiModalInputsV2, + MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement, + full_groupby_modality) from vllm.sequence import IntermediateTensors from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: InputContext): class LlavaMultiModalProcessor(BaseMultiModalProcessor): - def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): - if getattr(hf_processor, "__is_patched__", False): - return # Already patched - - image_processor = hf_processor.image_processor # type: ignore - orig_preprocess = image_processor.preprocess + def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: + return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor)) - def preprocess(__self, *args, **kwargs): - hf_inputs = orig_preprocess(*args, **kwargs) - hf_inputs["is_pixtral"] = torch.tensor(True) - return hf_inputs + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) - image_processor.preprocess = MethodType(preprocess, image_processor) + # NOTE: pixel_values=None for MLlavaProcessor + pixel_values = processed_outputs.get("pixel_values") + if pixel_values is not None: + images = mm_data["images"] + assert isinstance(images, list) - hf_processor.__is_patched__ = True # type: ignore + if isinstance(self._get_hf_processor(), PixtralProcessor): + # Original output: (1, num_images, C, H, W) + # New output: (num_images, C, H, W) + assert (isinstance(pixel_values, list) + and len(pixel_values) == 1 + and 
isinstance(pixel_values[0], list) + and len(pixel_values[0]) == len(images)) - def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: - hf_processor = self.ctx.get_hf_processor( - (LlavaProcessor, PixtralProcessor)) + processed_outputs["pixel_values"] = pixel_values[0] - if isinstance(hf_processor, PixtralProcessor): - self._patch_pixtral_processor(hf_processor) + return processed_outputs - return hf_processor + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(LlavaConfig) image_token_id = hf_config.image_token_index @@ -200,7 +219,7 @@ def _get_dummy_mm_inputs( ) -> ProcessorInputs: hf_config = self.ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config - num_images = mm_counts["image"] + num_images = mm_counts.get("image", 0) if isinstance(vision_config, CLIPVisionConfig): data = dummy_image_for_clip(vision_config, num_images) @@ -218,7 +237,6 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, ) @@ -379,7 +397,6 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False])) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -390,33 +407,6 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}") - assert isinstance(is_pixtral, torch.Tensor) - if is_pixtral.any(): - images = pixel_values - - def flatten_to_3d_tensors(item): - if isinstance(item, torch.Tensor): - if item.dim() >= 3: - return [t for t in item.view(-1, *item.shape[-3:])] - else: - raise ValueError( - f"Unexpected tensor dimension: {item.dim()}") - elif isinstance(item, list): - return [ - t for subitem in item - for t in flatten_to_3d_tensors(subitem) - ] - else: - raise ValueError(f"Unexpected type: {type(item)}") - - # Restructure the batched images into a list of lists of images - images = flatten_to_3d_tensors(pixel_values) - - return LlavaImagePixelInputs( - type="pixel_values", - data=images, - ) - return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( @@ -586,19 +576,71 @@ def load_weights(self, weights: Iterable[Tuple[str, class MantisMultiModalProcessor(LlavaMultiModalProcessor): - def _get_hf_processor(self) -> ProcessorMixin: - try: - from mantis.models.mllava import MLlavaProcessor - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "You need to `pip install " - "git+https://github.com/TIGER-AI-Lab/Mantis.git` " - "to use this model") from exc - - processor = MLlavaProcessor.from_pretrained( - self.ctx.model_config.tokenizer) - assert isinstance(processor, ProcessorMixin) - return processor + def _get_hf_processor(self): + return self.ctx.get_hf_processor(LlavaProcessor) + + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + hf_config = self.ctx.get_hf_config(LlavaConfig) + image_token_id = hf_config.image_token_index + max_image_tokens = get_max_llava_image_tokens(self.ctx) + + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + mm_items = self._get_mm_items(mm_data) + mm_item_counts = mm_items.get_item_counts() + mm_kwargs = result["mm_kwargs"] + + # We reimplement the functionality of MLlavaProcessor from + # https://github.com/TIGER-AI-Lab/Mantis.git + def get_replacement_mantis(item_idx: int): + return "".join([ + f"(image {item_idx+1}: ", # 7 tokens + "" * max_image_tokens, + ")", # 3 tokens + ]) + + mantis_repls = self._bind_prompt_replacements([ + PromptReplacement( + modality="image", + target=[image_token_id] * max_image_tokens, + replacement=get_replacement_mantis, + ) + ]) + + prompt_ids, prompt_text, _ = self._apply_prompt_replacements( + result["prompt_token_ids"], + mantis_repls, + mm_item_counts, + ) + + unbound_orig_repls = self._get_prompt_replacements( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + orig_repls = self._bind_prompt_replacements(unbound_orig_repls) + + all_placeholders = self._find_placeholders(orig_repls, prompt_ids, + mm_item_counts) + assert len(all_placeholders) == mm_item_counts.get("image", 0) + + mm_placeholders = { + modality: [item.to_range() for item in items] + for modality, items in full_groupby_modality(all_placeholders) + } + + return MultiModalInputsV2( + type="multimodal", + prompt=prompt_text, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_placeholders=mm_placeholders, + ) # To use this model, please use diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4e2e7f5761544..fefa9fd62d1d0 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -12,9 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -32,10 +32,14 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, + MultiModalFieldConfig, MultiModalInputsV2, + MultiModalKwargs, NestedTensors, + PlaceholderRange) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement, + _BoundPromptReplacement, + _PlaceholderInfo) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens( *, num_crops: Optional[int] = None, ) -> int: - mm_processor_kwargs = {} + hf_processor_mm_kwargs = {} if num_crops: - mm_processor_kwargs["num_crops"] = num_crops + hf_processor_mm_kwargs["num_crops"] = num_crops - processor = ctx.get_hf_processor(**mm_processor_kwargs) + processor = ctx.get_hf_processor(**hf_processor_mm_kwargs) return processor.calc_num_image_tokens_from_image_size( width=MAX_IMAGE_FEATURE_SIZE_WIDTH, @@ -331,39 +335,50 @@ def _get_hf_processor( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: processed_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) + input_ids = processed_outputs["input_ids"] + assert isinstance(input_ids, torch.Tensor) + # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids, # which will cause OverflowError when decoding the prompt_ids. 
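        # (e.g. a row like [1, -1, -1, -2, -2, 29871] cannot be decoded as-is)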
# Therefore, we need to do an early replacement here - token_ids = processed_outputs['input_ids'] - token_ids[token_ids < 0] = _IMAGE_TOKEN_ID - processed_outputs['input_ids'] = token_ids + input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID) return processed_outputs + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_tokens: list[str] = hf_processor.img_tokens # type: ignore image_processor = hf_processor.image_processor # type: ignore - mm_config = self.ctx.get_mm_config() - max_images = mm_config.limit_per_prompt.get("image", 1) + tokenizer = self._get_tokenizer() + bos_token_id = tokenizer.bos_token_id + assert isinstance(bos_token_id, int) def get_replacement_phi3v(item_idx: int): image_size = mm_items.get_image_size(item_idx) @@ -372,21 +387,44 @@ def get_replacement_phi3v(item_idx: int): height=image_size.height, ) - return [_IMAGE_TOKEN_ID] * num_tokens + return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id] return [ PromptReplacement( modality="image", target=image_token, replacement=get_replacement_phi3v, - ) for image_token in image_tokens[:max_images] + ) for image_token in image_tokens[:len(mm_items.images)] ] + def _apply_prompt_replacements( + self, + token_ids: list[int], + prompt_repls: Sequence[_BoundPromptReplacement], + mm_item_counts: Mapping[str, int], + ) -> tuple[list[int], str, list[_PlaceholderInfo]]: + token_ids, text, placeholders = super()._apply_prompt_replacements( + token_ids=token_ids, + prompt_repls=prompt_repls, + mm_item_counts=mm_item_counts, + ) + + # Keep the behavior in line with HF processor + if text.startswith(" <|image|>"): + text = text.replace(" <|image|>", "<|image|>", 1) + token_ids = [token_ids[0], *token_ids[2:]] + placeholders = [ + _PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement) + for p in placeholders + ] + + return token_ids, text, placeholders + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - num_images = mm_counts["image"] + num_images = mm_counts.get("image", 0) data = dummy_image_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, @@ -401,9 +439,28 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text="".join(image_tokens[:num_images]), mm_data=data, - mm_processor_kwargs={}, ) + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) + + # Only <|image|> tokens should be considered as placeholders, + # so we ignore the trailing bos_token_id + result["mm_placeholders"] = { + modality: [ + PlaceholderRange(offset=p["offset"], length=p["length"] - 1) + for p in ps + ] + for modality, ps in result["mm_placeholders"].items() + } + + return result + @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor) diff --git a/vllm/model_executor/models/qwen.py 
b/vllm/model_executor/models/qwen.py index 63d1374ab4092..baf955f6b515d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -225,7 +225,7 @@ def __init__( d_model: int, n_head: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +266,7 @@ def __init__( layers: int, heads: int, mlp_ratio: float = 4.0, - norm_layer: Callable = nn.LayerNorm, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 6259166a7fc57..25a351bd9c656 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -26,7 +26,7 @@ import numpy as np import torch import torch.nn as nn -from transformers import BatchFeature, ProcessorMixin +from transformers import BatchFeature from transformers.models.qwen2_audio import (Qwen2AudioConfig, Qwen2AudioEncoder, Qwen2AudioProcessor) @@ -38,10 +38,10 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -73,7 +73,7 @@ def forward(self, audio_features): # From Qwen2AudioEncoder._get_feat_extract_output_lengths -def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): feat_lengths = (input_lengths - 1) // 2 + 1 output_lengths = (feat_lengths - 2) // 2 + 1 return feat_lengths, output_lengths @@ -88,13 +88,18 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int: class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): - def _get_hf_processor(self) -> Qwen2AudioProcessor: + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor) def _get_feature_extractor(self) -> WhisperFeatureExtractor: return self._get_hf_processor().feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -102,50 +107,61 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if audios: - processor_data["audios"] = audios + mm_data["audios"] = audios 
feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) else: # NOTE: WhisperFeatureExtractor cannot handle empty list of audios pass - return super()._call_hf_processor( - hf_processor, + processed_outputs = super()._call_hf_processor( prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), ) def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) placeholder = hf_config.audio_token_index - feature_attention_mask = hf_inputs.get("feature_attention_mask") + feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") if feature_attention_mask is None: audio_output_lengths = [] else: + assert isinstance(feature_attention_mask, torch.Tensor) _, audio_output_lengths = _get_feat_extract_output_lengths( feature_attention_mask.sum(-1)) @@ -168,14 +184,13 @@ def _get_dummy_mm_inputs( sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate - audio_count = mm_counts["audio"] + audio_count = mm_counts.get("audio", 0) audio = np.zeros(audio_len) data = {"audio": [audio] * audio_count} return ProcessorInputs( prompt_text="<|AUDIO|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index fb97eb1916002..574845ef5a525 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,9 +22,10 @@ # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" from functools import cached_property, partial -from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, - Tuple, Type, TypedDict, Union) +from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, + Set, Tuple, Type, TypedDict, Union) +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -54,10 +55,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, + MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement) from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -229,9 +231,9 @@ class Qwen2VisionAttention(nn.Module): def __init__( self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, + embed_dim: int, + num_heads: int, + projection_size: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -264,7 +266,7 @@ def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, + rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -347,7 +349,7 @@ def __init__( num_heads: int, mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -384,7 +386,7 @@ def __init__( self, patch_size: int = 14, temporal_patch_size: int = 2, - in_chans: int = 3, + in_channels: int = 3, embed_dim: int = 1152, ) -> None: super().__init__() @@ -392,8 +394,8 @@ def __init__( self.temporal_patch_size = temporal_patch_size self.embed_dim = embed_dim - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_chans, + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, @@ -413,7 +415,7 @@ def __init__( self, d_model: int, context_dim: int, - norm_layer: Type[nn.Module] = None, + norm_layer: Optional[Callable[[int], nn.Module]] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -489,15 +491,15 @@ def __init__( ) -> None: super().__init__() - patch_size: int = vision_config.patch_size - temporal_patch_size: int = vision_config.temporal_patch_size - spatial_merge_size: int = vision_config.spatial_merge_size - in_chans: int = vision_config.in_chans - hidden_size: int = vision_config.hidden_size - embed_dim: int = vision_config.embed_dim - depth: int = vision_config.depth - num_heads: int = vision_config.num_heads - mlp_ratio: float = vision_config.mlp_ratio + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = 
vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads @@ -506,7 +508,7 @@ def __init__( self.patch_embed = Qwen2VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, - in_chans=in_chans, + in_channels=in_channels, embed_dim=embed_dim, ) @@ -733,8 +735,12 @@ def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": if k == "video": # Special case since even a single item can be a list multi_data[k] = ( # type: ignore[index] - v if (isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment] - or is_list_of(v, list)) else [v] + v if ( + isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment] + or is_list_of(v, list) + or isinstance(v[0], (np.ndarray, torch.Tensor)) + and v[0].ndim == 4 + ) else [v] ) elif k in ("image", "audio"): multi_data[k] = ( # type: ignore[index] @@ -754,6 +760,12 @@ def get_item_counts(self) -> Mapping[str, int]: for m, items in self.items() } + def has_embedding_inputs(self) -> bool: + return any( + isinstance(items, dict) or any( + isinstance(item, torch.Tensor) for item in items) + for items in self.values()) + class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): @@ -784,7 +796,7 @@ def _get_hf_processor( return hf_processor - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -805,7 +817,7 @@ def _get_processor_data( and v[0].ndim == 2): # Pass through embedding inputs (multi) passthrough_data[f"{k}_embeds"] = v - else: + elif len(v) > 0: # Map keys to plural form, e.g.: image -> images processor_data[f"{k}s"] = v else: @@ -816,8 +828,8 @@ def _get_processor_data( def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_processor = _get_image_processor(hf_processor) @@ -831,7 +843,9 @@ def _get_prompt_replacements( merge_length = image_processor.merge_size**2 def get_replacement_qwen2vl(item_idx: int, modality: str): - grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx] + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + num_tokens = grid_thw.prod() // merge_length return placeholder[modality] * num_tokens @@ -844,11 +858,40 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): ) for modality in ("image", "video") ] + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist() + image_slices = [ + slice(image_slice_idxs[i], image_slice_idxs[i + 1]) + for i in range(len(image_grid_thw)) + ] + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist() + video_slices = [ + slice(video_slice_idxs[i], video_slice_idxs[i + 1]) + for i in range(len(video_grid_thw)) + ] + + return dict( + pixel_values=MultiModalFieldConfig.flat("image", image_slices), + image_embeds=MultiModalFieldConfig.flat("image", image_slices), + 
image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat( + "video", video_slices), + video_embeds=MultiModalFieldConfig.flat("video", video_slices), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - num_images = mm_counts["image"] + num_images = mm_counts.get("image", 0) hf_processor = self._get_hf_processor() image_token: str = hf_processor.image_token image_processor = _get_image_processor(hf_processor) @@ -869,7 +912,6 @@ def _get_dummy_mm_inputs( return ProcessorInputs( prompt_text=image_token * num_images, mm_data=data, - mm_processor_kwargs={}, ) @@ -950,9 +992,7 @@ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): return None return quant_config - def _validate_and_reshape_mm_tensor(self, - mm_input: Union[torch.Tensor, - List[torch.Tensor]], + def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. " @@ -962,7 +1002,8 @@ def _validate_and_reshape_mm_tensor(self, return mm_input if mm_input.ndim != 3: raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim}") + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") return torch.concat(list(mm_input)) else: return torch.concat(mm_input) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 509ad9e580ddf..7b4aeeec5f403 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -23,10 +23,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, - PromptReplacement) + ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of @@ -72,11 +73,19 @@ def get_ultravox_max_audio_tokens(ctx: InputContext): class UltravoxMultiModalProcessor(BaseMultiModalProcessor): + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> ProcessorMixin: + return self.ctx.get_hf_processor() + def _get_feature_extractor(self) -> WhisperFeatureExtractor: hf_processor = self._get_hf_processor() return hf_processor.audio_processor.feature_extractor # type: ignore - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -84,33 +93,41 @@ def _get_processor_data( feature_extractor = self._get_feature_extractor() mm_items.resample_audios(feature_extractor.sampling_rate) - return super()._get_processor_data(mm_items) + return super()._get_hf_mm_data(mm_items) def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: - 
processor_data = dict(processor_data) - audios = processor_data.pop("audios", []) + # Text-only input not supported in composite processor + if not mm_data: + tokenizer = self._get_tokenizer() + + prompt_ids = tokenizer.encode( + prompt, + add_special_tokens=False, # type: ignore + ) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) if not audios: return super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=mm_data, + mm_kwargs=mm_kwargs, ) feature_extractor = self._get_feature_extractor() - mm_processor_kwargs = dict( - **mm_processor_kwargs, + mm_kwargs = dict( + **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) - # Already resampled by _get_processor_data + # Already resampled by _get_hf_mm_data assert is_list_of(audios, np.ndarray) # Ultravox processor doesn't support multiple inputs, @@ -119,13 +136,12 @@ def _call_hf_processor( shared_outputs = {} for audio in audios: # NOTE: Ultravox processor accepts "audio" instead of "audios" - item_processor_data = dict(**processor_data, audio=audio) + item_processor_data = dict(**mm_data, audio=audio) item_outputs = super()._call_hf_processor( - hf_processor, prompt=prompt, - processor_data=item_processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, ) audio_features.append(item_outputs.pop("audio_values")[0]) @@ -139,17 +155,28 @@ def _call_hf_processor( ) return BatchFeature(combined_outputs) + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + audio_features=MultiModalFieldConfig.batched("audio"), + audio_token_len=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.batched("audio"), + ) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() placeholder = hf_processor.audio_token_replacement # type: ignore def get_replacement_ultravox(item_idx: int): - audio_token_len = hf_inputs["audio_token_len"][item_idx] + audio_token_len = out_mm_kwargs["audio_token_len"][item_idx] return placeholder * audio_token_len return [ @@ -168,14 +195,13 @@ def _get_dummy_mm_inputs( sampling_rate = feature_extractor.sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate - audio_count = mm_counts["audio"] + audio_count = mm_counts.get("audio", 0) audio = np.zeros(audio_len) data = {"audio": [audio] * audio_count} return ProcessorInputs( prompt_text="<|audio|>" * audio_count, mm_data=data, - mm_processor_kwargs={}, ) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 10488e24b30cc..cdda6f8052794 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -297,35 +297,37 @@ def from_seq_group( ``MultiModalPlaceholderMap`` that relates the multi-modal embedding vectors to their corresponding placeholders. - Consider the following scenarios: + Examples: - Prompt: |AAAA BBBB What's in these images?| - Positions: |.................................| + .. 
code-block:: - images = [A, B] - src_ranges = [(0, 4), (4, 8)] - dest_ranges = [(0, 4), (5, 9)] + Prompt: |AAAA BBBB What's in these images?| + Positions: |.................................| - Prompt: |AAAA BBBB What's in these images?| - Positions: | ..... | + images = [A, B] + src_ranges = [(0, 4), (4, 8)] + dest_ranges = [(0, 4), (5, 9)] - images = [A, B] - src_ranges = [(2, 4), (4, 6)] - dest_ranges = [(0, 2), (3, 5)] + Prompt: |AAAA BBBB What's in these images?| + Positions: | ..... | - Prompt: |AAAA BBBB What's in these images?| - Positions: | ......... | + images = [A, B] + src_ranges = [(2, 4), (4, 6)] + dest_ranges = [(0, 2), (3, 5)] - images = [B] - src_ranges = [(0, 4)] - dest_ranges = [(0, 4)] + Prompt: |AAAA BBBB What's in these images?| + Positions: | ......... | - Prompt: |AAAA BBBB What's in these images?| - Positions: | .......................| + images = [B] + src_ranges = [(0, 4)] + dest_ranges = [(0, 4)] - images = [] - src_ranges = [] - dest_ranges = [] + Prompt: |AAAA BBBB What's in these images?| + Positions: | .......................| + + images = [] + src_ranges = [] + dest_ranges = [] """ seq_mm_data = seq_group.multi_modal_data seq_mm_placeholders = seq_group.multi_modal_placeholders diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 9ecae2c1ca2bf..1fbda6e0b8750 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -1,12 +1,16 @@ +from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Any, Dict, List, Literal, Mapping, Sequence, Tuple, - TypedDict, TypeVar, Union, cast, final) +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast, + final) import numpy as np import torch import torch.types from PIL.Image import Image -from typing_extensions import NotRequired, TypeAlias +from transformers import BatchFeature +from typing_extensions import NotRequired, TypeAlias, assert_never from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -44,7 +48,7 @@ """ # yapf: enable -MultiModalData: TypeAlias = Union[_T, List[_T]] +MultiModalData: TypeAlias = Union[_T, list[_T]] """ Either a single data item, or a list of data items. @@ -79,13 +83,135 @@ class MultiModalDataBuiltins(TypedDict, total=False): """ +class ImageSize(NamedTuple): + width: int + height: int + + +class MultiModalDataItems(UserDict[str, list[Any]]): + """ + As :class:`MultiModalDataDict`, but normalized such that each entry + corresponds to a list. + """ + + @staticmethod + def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": + """ + Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. + """ + multi_data = MultiModalDataItems() + + for k, v in data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + # yapf: disable + if k == "video": + # Special case since even a single item can be a list + multi_data[k] = ( # type: ignore[index] + v if ( + isinstance(v, torch.Tensor) + or is_list_of(v, list) + or isinstance(v[0], (np.ndarray, torch.Tensor)) + and v[0].ndim == 4 + ) else [v] + ) + elif k in ("image", "audio"): + multi_data[k] = ( # type: ignore[index] + v if isinstance(v, (torch.Tensor, list)) else [v] + ) + else: + multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] + # yapf: enable + + return multi_data + + # NOTE: When a field (e.g. 
`images`) doesn't exist, directly appending to + # `self.images` doesn't update this dictionary, which may be confusing + # We annotate the getter methods as `Sequence` to prevent others from + # trying to update the list in this way + @property + def images(self) -> Sequence[ImageItem]: + return self.get("image", []) + + @property + def videos(self) -> Sequence[VideoItem]: + return self.get("video", []) + + @property + def audios(self) -> Sequence[AudioItem]: + return self.get("audio", []) + + def get_item_counts(self) -> Mapping[str, int]: + return {m: len(items) for m, items in self.items()} + + def has_embedding_inputs(self) -> bool: + return any( + any(isinstance(item, torch.Tensor) for item in items) + for items in self.values()) + + def get_image_size(self, item_idx: int) -> ImageSize: + image = self.images[item_idx] + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + + def get_audio_with_sr( + self, + item_idx: int, + *, + default_sr: float, + ) -> tuple[np.ndarray, float]: + audio = self.audios[item_idx] + + if isinstance(audio, tuple): + return audio + if isinstance(audio, list): + return np.array(audio), default_sr + if isinstance(audio, np.ndarray): + return audio, default_sr + + assert_never(audio) + + def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None: + """ + If :code:`drop_sr=True`, the audio items in this dictionary are updated + to be NumPy arrays which implicitly means that their sampling rate is + the same as the model's expected sampling rate; otherwise, they remain + as :code:`(audio, new_sr)` tuples. + """ + # Avoid circular import + from .audio import resample_audio + + if not self.audios: + return + + new_audios = [] + for item_idx in range(len(self.audios)): + audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr) + audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr) + + new_audios.append(audio if drop_sr else (audio, new_sr)) + + self["audio"] = new_audios + + class PlaceholderRange(TypedDict): """ Placeholder location information for multi-modal data. - For example: - Prompt: AAAA BBBB What is in these images? + Example: + + Prompt: :code:`AAAA BBBB What is in these images?` + Images A and B will have: + + .. code-block:: + A: { "offset": 0, "length": 4 } B: { "offset": 5, "length": 4 } """ @@ -97,25 +223,256 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, - Tuple[torch.Tensor, ...]] +NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, + tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. 
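
For example (an illustrative sketch; the shapes are arbitrary):

.. code-block:: python

    # Patch features for two images with different numbers of patches;
    # the mismatched first dimensions prevent stacking into a single tensor.
    [torch.rand(576, 1024), torch.rand(196, 1024)]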
""" -BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] + +def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: + """Equality check between :data:`NestedTensors` objects.""" + if isinstance(a, torch.Tensor): + return isinstance(b, torch.Tensor) and bool((a == b).all().item()) + elif isinstance(b, torch.Tensor): + return isinstance(a, torch.Tensor) and bool((b == a).all().item()) + + if isinstance(a, list): + return (isinstance(b, list) + and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))) + if isinstance(b, list): + return (isinstance(a, list) + and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))) + + # Both a and b are scalars + return a == b + + +BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via :meth:`MultiModalKwargs.batch`. """ +@dataclass(frozen=True) +class MultiModalFieldItem: + """ + Contains metadata and data in :class:`MultiModalKwargs` + corresponding to a data item in :class:`MultiModalDataItems`. + """ + field: "BaseMultiModalField" + data: NestedTensors + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + + return (self.field == other.field + and nested_tensors_equal(self.data, other.data)) + + +@dataclass(frozen=True) +class BaseMultiModalField(ABC): + """Abstract base class for a field in :class:`MultiModalKwargs`.""" + key: str + modality: str + + @abstractmethod + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + raise NotImplementedError + + def _build_item(self, data: NestedTensors) -> MultiModalFieldItem: + return MultiModalFieldItem(self, data) + + def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem: + """Merge multiple instances of :class:`MultiModalFieldItem` together.""" + fields = [item.field for item in batch] + if len(set(fields)) > 1: + raise ValueError(f"Cannot merge different {fields=}") + + data = self._reduce_data([item.data for item in batch]) + + return self._build_item(data) + + +@dataclass(frozen=True) +class MultiModalBatchedField(BaseMultiModalField): + """ + A :class:`BaseMultiModalField` implementation where an item is obtained by + directly indexing into the first dimension of the underlying data. + """ + + def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]: + return [self._build_item(item) for item in batch] + + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): + first_shape = batch[0].shape + if all(item.shape == first_shape for item in batch): + return torch.stack(batch) + + return batch + + +@dataclass(frozen=True) +class MultiModalFlatField(BaseMultiModalField): + """ + A :class:`BaseMultiModalField` implementation where an item is obtained by + slicing along the first dimension of the underlying data. 
+    """
+
+    def build_items(
+        self,
+        batch: NestedTensors,
+        slices: Sequence[slice],
+    ) -> list[MultiModalFieldItem]:
+        return [self._build_item(batch[slice_]) for slice_ in slices]
+
+    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
+        if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            first_shape = batch[0].shape
+            if all(item.shape[1:] == first_shape[1:] for item in batch):
+                return torch.concat(batch)
+
+        return [elem for item in batch for elem in item]
+
+
+class MultiModalFieldConfig:
+
+    @staticmethod
+    def batched(modality: str):
+        return MultiModalFieldConfig(
+            field_cls=MultiModalBatchedField,
+            modality=modality,
+        )
+
+    @staticmethod
+    def flat(modality: str, slices: Sequence[slice]):
+        return MultiModalFieldConfig(
+            field_cls=MultiModalFlatField,
+            modality=modality,
+            slices=slices,
+        )
+
+    def __init__(
+        self,
+        field_cls: type[BaseMultiModalField],
+        modality: str,
+        **field_config: Any,
+    ) -> None:
+        super().__init__()
+
+        self._field_cls = field_cls
+        self._modality = modality
+        self._field_config = field_config
+
+    def build_items(
+        self,
+        key: str,
+        batch: NestedTensors,
+    ) -> list[MultiModalFieldItem]:
+        field = self._field_cls(key=key, modality=self._modality)
+        return field.build_items(batch, **self._field_config)  # type: ignore
+
+
 class MultiModalKwargs(UserDict[str, NestedTensors]):
     """
     A dictionary that represents the keyword arguments to
     :meth:`~torch.nn.Module.forward`.
+
+    The metadata :code:`items_by_key` defines how to split batched keyword
+    arguments corresponding to each data item in :class:`MultiModalDataItems`:
+
+    - For a keyword argument, we can access the :code:`i` th item in the batch
+      via :code:`items_by_key[key][i]`.
+    - We can gather the keyword arguments belonging to a modality by finding
+      the keys with items that belong to that modality, then accessing
+      the :code:`i` th item in the batch for each such key.
+
+    Example:
+
+    .. code-block:: python
+
+        items_by_key={
+            "pixel_values": [a, b, c, d],  # "image" modality
+            "image_grid_thw": [e, f, g, h],  # "image" modality
+            "pixel_values_video": [h, i, j],  # "video" modality
+            "video_grid_thw": [k, l, m],  # "video" modality
+        }
+
+    - The keyword arguments belonging to the first image are
+      :code:`{"pixel_values": a, "image_grid_thw": e}`.
+    - The keyword arguments belonging to the second video are
+      :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
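A rough sketch of the difference between the two field configs above, in terms of the reduction they select; the tensor sizes and slice boundaries are invented for illustration and this is not the vLLM API itself.

```python
import torch

# "batched" fields: one item per row of the first dimension.
pixel_values = torch.randn(4, 3, 336, 336)            # 4 images
batched_items = [pixel_values[i] for i in range(4)]

# "flat" fields: items are variable-length slices along the first
# dimension, so merging them back is a concat rather than a stack.
image_embeds = torch.randn(7, 1024)                    # 3 + 4 feature rows
slices = [slice(0, 3), slice(3, 7)]
flat_items = [image_embeds[s] for s in slices]

print(len(batched_items), batched_items[0].shape)  # 4 torch.Size([3, 336, 336])
print(torch.cat(flat_items).shape)                 # torch.Size([7, 1024])
```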
""" + @staticmethod + def from_hf_inputs( + hf_inputs: BatchFeature, + config_by_key: Mapping[str, MultiModalFieldConfig], + *, + enable_sanity_checks: bool = False, + ): + # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` + # We assume that those fields are not used in vLLM + items_by_key = { + key: config.build_items(key, batch) + for key, config in config_by_key.items() + if (batch := hf_inputs.get(key)) is not None + } + + return MultiModalKwargs.from_items_by_key( + items_by_key, + enable_sanity_checks=enable_sanity_checks, + ) + + @staticmethod + def from_items_by_key( + items_by_key: Mapping[str, list[MultiModalFieldItem]], + *, + enable_sanity_checks: bool = False, + ) -> "MultiModalKwargs": + data = { + key: items[0].field.reduce(items).data + for key, items in items_by_key.items() + } + + return MultiModalKwargs(data, + items_by_key=items_by_key, + enable_sanity_checks=enable_sanity_checks) + + def __init__( + self, + data: Mapping[str, NestedTensors], + *, + items_by_key: Mapping[str, list[MultiModalFieldItem]] = {}, + enable_sanity_checks: bool = False, + ) -> None: + super().__init__(data) + + # Shallow copy to avoid footgun in case a defaultdict is passed in + self._items_by_key = dict(items_by_key) + + keys_by_modality = defaultdict[str, set[str]](set) + for key, items in items_by_key.items(): + for item in items: + keys_by_modality[item.field.modality].add(key) + + self._keys_by_modality = dict(keys_by_modality) + + if enable_sanity_checks: + for modality, keys in keys_by_modality.items(): + items_in_modality = {k: items_by_key[k] for k in keys} + batch_sizes = {k: len(v) for k, v in items_in_modality.items()} + batch_size = next(iter(batch_sizes.values()), 0) + assert all(bs == batch_size + for bs in batch_sizes.values()), dict( + modality=modality, + batch_sizes=batch_sizes, + items_by_key=items_by_key) + @staticmethod def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ @@ -139,7 +496,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: # Only tensors (not lists) can be stacked. return stacked - tensors_ = cast(List[torch.Tensor], stacked) + tensors_ = cast(list[torch.Tensor], stacked) if any(t.shape != tensors_[0].shape for t in tensors_): # The tensors have incompatible shapes and can't be stacked. return tensors_ @@ -147,7 +504,7 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: return torch.stack(tensors_) @staticmethod - def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: + def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs: """ Batch multiple inputs together into a dictionary. @@ -162,7 +519,7 @@ def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs: # We need to consider the case where each item in the batch # contains different modalities (i.e. different keys). 
-        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
+        item_lists = defaultdict[str, list[NestedTensors]](list)
 
         for inputs in inputs_list:
             for k, v in inputs.items():
@@ -188,6 +545,57 @@ def as_kwargs(
 
         return cast(BatchedTensorInputs, json_mapped)
 
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, self.__class__):
+            return False
+        if self._items_by_key != other._items_by_key:
+            return False
+
+        ks = self.keys()
+        return (ks == other.keys()
+                and all(nested_tensors_equal(self[k], other[k]) for k in ks))
+
+    def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
+        return self._items_by_key[key][item_index]
+
+    def get_items_by_modality(
+        self,
+        modality: str,
+        item_index: int,
+    ) -> Mapping[str, MultiModalFieldItem]:
+        """
+        Get the keyword arguments corresponding to an item identified by
+        its modality and index.
+        """
+        keys_to_gather = self._keys_by_modality[modality]
+
+        return {
+            key: self.get_item(key, item_index)
+            for key in keys_to_gather if key in self
+        }
+
+    @staticmethod
+    def from_items_by_modality(
+        items_by_modality: Mapping[str, list[Mapping[str,
+                                                      MultiModalFieldItem]]],
+        *,
+        enable_sanity_checks: bool = False,
+    ) -> "MultiModalKwargs":
+        """
+        Construct a new :class:`MultiModalKwargs` from multiple items returned
+        by :meth:`get_items_by_modality`.
+        """
+        items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
+        for fields in items_by_modality.values():
+            for field in fields:
+                for k, v in field.items():
+                    items_by_key[k].append(v)
+
+        return MultiModalKwargs.from_items_by_key(
+            items_by_key,
+            enable_sanity_checks=enable_sanity_checks,
+        )
+
 
 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
 """
@@ -207,16 +615,16 @@ class MultiModalInputsV2(TypedDict):
     prompt: str
     """The processed prompt text."""
 
-    prompt_token_ids: List[int]
+    prompt_token_ids: list[int]
     """The processed token IDs which includes placeholder tokens."""
 
-    token_type_ids: NotRequired[List[int]]
+    token_type_ids: NotRequired[list[int]]
     """The token type IDs of the prompt."""
 
     mm_kwargs: MultiModalKwargs
     """Keyword arguments to be directly passed to the model after batching."""
 
-    mm_hashes: NotRequired[List[str]]
+    mm_hashes: NotRequired[list[str]]
     """The hashes of the multi-modal data."""
 
     mm_placeholders: MultiModalPlaceholderDict
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 6baf19d675d50..3ece0762e3228 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,6 +1,6 @@
+import pickle
 import re
 from abc import ABC, abstractmethod
-from collections import UserDict
 from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass, field
 from functools import lru_cache
@@ -8,19 +8,18 @@
 import numpy as np
 import torch
+from blake3 import blake3
 from PIL.Image import Image
 from transformers import BatchFeature, ProcessorMixin
-from typing_extensions import assert_never
 
 from vllm.inputs import DummyData, InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
+from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of
 
-from .audio import resample_audio
-from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
-                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
-                     VideoItem)
+from .inputs import (MultiModalDataDict, MultiModalDataItems,
+
MultiModalFieldConfig, MultiModalFieldItem, + MultiModalInputsV2, MultiModalKwargs, PlaceholderRange) logger = init_logger(__name__) @@ -201,111 +200,6 @@ def get_replacement(self, item_idx: int) -> _BoundPromptSequence: return bound_replacement -class ImageSize(NamedTuple): - width: int - height: int - - -class MultiModalDataItems(UserDict[str, list[Any]]): - """ - As :class:`MultiModalDataDict`, but normalized such that each entry - corresponds to a list. - """ - - @staticmethod - def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": - """ - Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`. - """ - multi_data = MultiModalDataItems() - - for k, v in data.items(): - # TODO: Make a separate modality for embedding inputs - # to avoid confusion - # yapf: disable - if k == "video": - # Special case since even a single item can be a list - multi_data[k] = ( # type: ignore[index] - v if (isinstance(v, torch.Tensor) - or is_list_of(v, list)) else [v] - ) - elif k in ("image", "audio"): - multi_data[k] = ( # type: ignore[index] - v if isinstance(v, (torch.Tensor, list)) else [v] - ) - else: - multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index] - # yapf: enable - - return multi_data - - # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to - # `self.images` doesn't update this dictionary, which may be confusing - # We annotate the getter methods as `Sequence` to prevent others from - # trying to update the list in this way - @property - def images(self) -> Sequence[ImageItem]: - return self.get("image", []) - - @property - def videos(self) -> Sequence[VideoItem]: - return self.get("video", []) - - @property - def audios(self) -> Sequence[AudioItem]: - return self.get("audio", []) - - def get_item_counts(self) -> Mapping[str, int]: - return {m: len(items) for m, items in self.items()} - - def get_image_size(self, item_idx: int) -> ImageSize: - image = self.images[item_idx] - - if isinstance(image, Image): - return ImageSize(*image.size) - if isinstance(image, (np.ndarray, torch.Tensor)): - _, h, w = image.shape - return ImageSize(w, h) - - assert_never(image) - - def get_audio_with_sr( - self, - item_idx: int, - *, - default_sr: float, - ) -> tuple[np.ndarray, float]: - audio = self.audios[item_idx] - - if isinstance(audio, tuple): - return audio - if isinstance(audio, list): - return np.array(audio), default_sr - if isinstance(audio, np.ndarray): - return audio, default_sr - - assert_never(audio) - - def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None: - """ - If :code:`drop_sr=True`, the audio items in this dictionary are updated - to be NumPy arrays which implicitly means that their sampling rate is - the same as the model's expected sampling rate; otherwise, they remain - as :code:`(audio, new_sr)` tuples. 
- """ - if not self.audios: - return - - new_audios = [] - for item_idx in range(len(self.audios)): - audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr) - audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr) - - new_audios.append(audio if drop_sr else (audio, new_sr)) - - self["audio"] = new_audios - - class _TokenMatch(NamedTuple): start_idx: int end_idx: int @@ -583,11 +477,124 @@ def iter_placeholders( ) -class ProcessorInputs(NamedTuple): - """Keyword arguments to :meth:`BaseMultiModalProcessor`""" +@dataclass +class ProcessorInputs: + """Keyword arguments to :meth:`BaseMultiModalProcessor`.""" prompt_text: str mm_data: MultiModalDataDict - mm_processor_kwargs: Mapping[str, object] + hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) + + +class ProcessingCache: + + def __init__(self, capacity: int) -> None: + super().__init__() + + # DEBUG: Set to None to disable + self.debug_cache_hit_ratio_steps: Optional[int] = None + + self._cache = LRUCache[str, Mapping[str, + MultiModalFieldItem]](capacity) + + def _maybe_log_cache_stats(self) -> None: + steps = self.debug_cache_hit_ratio_steps + if not steps: + return + + cache_stats = self._cache.stat() + if cache_stats.total % steps == 0: + logger.debug("ProcessingCache: hit_ratio = %.2f", + cache_stats.hit_ratio) + + def _serialize_item(self, obj: object) -> bytes: + # Simple cases + if isinstance(obj, str): + return obj.encode("utf-8") + if isinstance(obj, bytes): + return obj + if isinstance(obj, Image): + return obj.tobytes() + + # Convertible to NumPy arrays + if isinstance(obj, torch.Tensor): + obj = obj.numpy() + if isinstance(obj, (int, float)): + obj = np.array(obj) + if isinstance(obj, np.ndarray): + return obj.tobytes() + + logger.warning( + "No serialization method found for %s. " + "Falling back to pickle.", type(obj)) + + return pickle.dumps(obj) + + def _item_to_bytes( + self, + key: str, + obj: object, + ) -> Iterable[tuple[bytes, bytes]]: + # Recursive cases + if isinstance(obj, (list, tuple)): + for i, elem in enumerate(obj): + yield from self._item_to_bytes(f"{key}.{i}", elem) + elif isinstance(obj, dict): + for k, v in obj.items(): + yield from self._item_to_bytes(f"{key}.{k}", v) + else: + key_bytes = self._serialize_item(key) + value_bytes = self._serialize_item(obj) + yield key_bytes, value_bytes + + def _hash_kwargs(self, **kwargs: object) -> str: + hasher = blake3() + + for k, v in kwargs.items(): + for k_bytes, v_bytes in self._item_to_bytes(k, v): + hasher.update(k_bytes) + hasher.update(v_bytes) + + return hasher.hexdigest() + + def get( + self, + model_id: str, + modality: str, + input_item: object, + input_kwargs: Mapping[str, object], + ) -> Optional[Mapping[str, MultiModalFieldItem]]: + """ + Get a processed multi-modal item from the cache + according to its dependencies, including: + + - The model ID + - The modality of the item + - The original data item passed to the HF processor + - The configuration options of the HF processor + """ + self._maybe_log_cache_stats() + + cache_key = self._hash_kwargs(model_id=model_id, + **{modality: input_item}, + **input_kwargs) + return self._cache.get(cache_key) + + def put( + self, + model_id: str, + modality: str, + input_item: object, + input_kwargs: Mapping[str, object], + output_kwargs: Mapping[str, MultiModalFieldItem], + ) -> None: + """ + Put a processed multi-modal item into the cache + according to its dependencies (see :meth:`get`). 
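A stripped-down sketch of the cache-key scheme above, assuming the same `blake3` package the module imports; it only handles strings and NumPy arrays, whereas the real `_serialize_item` also covers PIL images, tensors, and nested containers.

```python
import numpy as np
from blake3 import blake3

# Flatten every kwarg into (key bytes, value bytes) pairs and feed them
# into one rolling blake3 hash, so any change in the item or in the
# processor options yields a different cache key.
def hash_kwargs(**kwargs) -> str:
    hasher = blake3()
    for key, value in kwargs.items():
        hasher.update(key.encode("utf-8"))
        if isinstance(value, np.ndarray):
            hasher.update(value.tobytes())
        else:
            hasher.update(str(value).encode("utf-8"))
    return hasher.hexdigest()

image = np.zeros((8, 8, 3), dtype=np.uint8)
key_a = hash_kwargs(model_id="some/model", image=image, num_crops=4)
key_b = hash_kwargs(model_id="some/model", image=image, num_crops=8)
print(key_a != key_b)  # different processor kwargs -> different cache entries
```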
+ """ + cache_key = self._hash_kwargs(model_id=model_id, + **{modality: input_item}, + **input_kwargs) + self._cache.put(cache_key, output_kwargs) class BaseMultiModalProcessor(ABC): @@ -595,18 +602,24 @@ class BaseMultiModalProcessor(ABC): Abstract base class to process multi-modal inputs to be used in vLLM. """ - def __init__(self, ctx: InputProcessingContext) -> None: + def __init__(self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: super().__init__() self.ctx = ctx + self.cache = cache + self.enable_sanity_checks = enable_sanity_checks def __call__( self, prompt: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: - return self.apply(prompt, mm_data, mm_processor_kwargs) + return self.apply(prompt, mm_data, hf_processor_mm_kwargs) def _get_hf_processor(self) -> ProcessorMixin: """ @@ -624,12 +637,21 @@ def _get_mm_items( ) -> MultiModalDataItems: return MultiModalDataItems.from_dict(mm_data) + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + """Given the HF-processed data, output the metadata of each field.""" + raise NotImplementedError + @abstractmethod def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_inputs: BatchFeature, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: """ Given the original multi-modal items for this modality @@ -651,7 +673,7 @@ def _find_placeholders( return list( iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) - def _get_processor_data( + def _get_hf_mm_data( self, mm_items: MultiModalDataItems, ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -669,7 +691,7 @@ def _get_processor_data( and v[0].ndim == 2): # Pass through embedding inputs (multi) passthrough_data[f"{k}_embeds"] = v - else: + elif len(v) > 0: # Map keys to plural form, e.g.: image -> images processor_data[f"{k}s"] = v else: @@ -679,39 +701,181 @@ def _get_processor_data( def _call_hf_processor( self, - hf_processor: ProcessorMixin, prompt: str, - processor_data: Mapping[str, object], - mm_processor_kwargs: Mapping[str, object], + # Not to be confused with `mm_data` in `self.apply`. + # This refers to the data to be passed to HF processor. + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], ) -> BatchFeature: return self.ctx.call_hf_processor( - hf_processor, - prompt, - processor_data, - mm_processor_kwargs, + self._get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + mm_kwargs, ) def _apply_hf_processor( self, - prompt: str, + prompt_text: str, mm_items: MultiModalDataItems, - mm_processor_kwargs: Mapping[str, object], - ) -> BatchFeature: - # some mm_processor_kwargs may be used in processor initialization - # instead of processor call - hf_processor = self._get_hf_processor(**mm_processor_kwargs) + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Apply the HF processor on the full prompt text and multi-modal data. 
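A simplified sketch of the split performed by `_get_hf_mm_data` above: precomputed embeddings bypass the HF processor while everything else is renamed to the plural keyword the processor expects. The sample items and the two-dimensional embedding check are illustrative only; the real method handles more cases.

```python
import torch

# Toy multi-modal items: two images to be run through the HF processor
# and one precomputed 2-D audio embedding that should bypass it.
mm_items = {
    "image": ["<PIL image 0>", "<PIL image 1>"],
    "audio": [torch.randn(4, 80)],
}

processor_data, passthrough_data = {}, {}
for modality, items in mm_items.items():
    if items and isinstance(items[0], torch.Tensor) and items[0].ndim == 2:
        passthrough_data[f"{modality}_embeds"] = items
    elif items:
        processor_data[f"{modality}s"] = items  # e.g. image -> images

print(sorted(processor_data))    # ['images']
print(sorted(passthrough_data))  # ['audio_embeds']
```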
+ """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + ) + processed_data.update(passthrough_data) - processor_data, passthrough_data = self._get_processor_data(mm_items) + prompt_ids, = processed_data.pop("input_ids").tolist() - hf_inputs = self._call_hf_processor( - hf_processor, - prompt=prompt, - processor_data=processor_data, - mm_processor_kwargs=mm_processor_kwargs, + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + enable_sanity_checks=self.enable_sanity_checks, ) - hf_inputs.update(passthrough_data) - return hf_inputs + return prompt_ids, mm_kwargs + + def _apply_hf_processor_missing( + self, + prompt_text: str, + mm_missing_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ): + """ + Apply the HF processor on the full prompt text, but only on the + multi-modal data that are missing from the cache. + + Note: We pass prompt text and multi-modal data into the HF processor + in separate calls to avoid HF prompt replacement being done for + cached items; instead, we rely on our own prompt replacement logic + for the full text. + """ + mm_missing_counts = mm_missing_data_items.get_item_counts() + + prompt_ids, _ = self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=MultiModalDataItems({}), + hf_processor_mm_kwargs={}, + ) + + # Some HF processors (e.g. Qwen2-VL) expect corresponding + # multi-modal tokens to be in the prompt text + dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts) + + _, mm_missing_kwargs = self._apply_hf_processor( + prompt_text=dummy_inputs.prompt_text, + mm_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + return prompt_ids, mm_missing_kwargs + + def _cached_apply_hf_processor( + self, + prompt_text: str, + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> tuple[list[int], MultiModalKwargs]: + """ + Apply the HF processor on the full prompt text, + caching the results and reusing cached results. 
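A structural sketch of the two separate calls described above, using a stub in place of a real HF processor; all names and token IDs here are made up.

```python
# Stand-in for an HF processor: returns token ids for the text and
# per-image kwargs for whatever images it is given.
def stub_hf_processor(text: str, images=()):
    token_ids = list(range(len(text.split())))
    mm_kwargs = {"pixel_values": [f"processed({im})" for im in images]}
    return token_ids, mm_kwargs

# Call 1: the real prompt with no multi-modal data -> token IDs only,
# so the processor never expands placeholders for already-cached items.
prompt_ids, _ = stub_hf_processor("USER: <image> What is in the image?")

# Call 2: a dummy prompt plus only the cache-missing items -> mm kwargs.
_, missing_kwargs = stub_hf_processor("<image>", images=["uncached.jpg"])

print(prompt_ids, missing_kwargs["pixel_values"])
```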
+ """ + cache = self.cache + model_id = self.ctx.model_config.model + + if cache is None or mm_data_items.has_embedding_inputs(): + return self._apply_hf_processor( + prompt_text=prompt_text, + mm_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_maybe_cached_field_items = { + modality: [ + cache.get(model_id, modality, item, hf_processor_mm_kwargs) + for item in items + ] + for modality, items in mm_data_items.items() + } + + mm_missing_idxs = { + modality: [idx for idx, out in enumerate(fields) if out is None] + for modality, fields in mm_maybe_cached_field_items.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + mm_missing_data_items = self._get_mm_items(mm_missing_data) + + prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( + prompt_text=prompt_text, + mm_missing_data_items=mm_missing_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + + mm_missing_next_idx = { + modality: 0 + for modality in mm_missing_data_items + } + + mm_merged_field_items = dict[str, list[Mapping[str, + MultiModalFieldItem]]]() + for modality, modal_items_lst in mm_maybe_cached_field_items.items(): + merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]() + + for idx, modal_items in enumerate(modal_items_lst): + if modal_items is None: + modal_items = mm_missing_kwargs.get_items_by_modality( + modality, + mm_missing_next_idx[modality], + ) + + cache.put( + model_id, + modality, + mm_data_items[modality][idx], + hf_processor_mm_kwargs, + modal_items, + ) + + mm_missing_next_idx[modality] += 1 + + merged_modal_items_lst.append(modal_items) + + mm_merged_field_items[modality] = merged_modal_items_lst + + if self.enable_sanity_checks: + mm_missing_counts = mm_missing_data_items.get_item_counts() + assert all( + item_count == mm_missing_counts[modality] + for modality, item_count in mm_missing_next_idx.items()), dict( + mm_missing_next_idx=mm_missing_next_idx, + mm_missing_counts=mm_missing_counts) + + mm_kwargs = MultiModalKwargs.from_items_by_modality( + mm_merged_field_items, + enable_sanity_checks=self.enable_sanity_checks, + ) + + if self.enable_sanity_checks: + mm_item_counts = mm_data_items.get_item_counts() + + for modality, item_count in mm_item_counts.items(): + for item_idx in range(item_count): + try: + mm_kwargs.get_items_by_modality(modality, item_idx) + except Exception as e: + # Make it easy to set a breakpoint in the debugger + raise e + + return prompt_ids, mm_kwargs def _bind_prompt_replacements( self, @@ -730,6 +894,10 @@ def _apply_prompt_replacements( tokenizer = self._get_tokenizer() token_matches = find_token_matches(token_ids, prompt_repls) + mm_match_counts = { + modality: len(matches) + for modality, matches in full_groupby_modality(token_matches) + } # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -742,8 +910,8 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. 
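A plain-dict sketch of the hit/miss bookkeeping in `_cached_apply_hf_processor` above: look each item up, remember which indices missed, process only those, then merge the two sources back in the original order. The cache contents and item names are invented for illustration.

```python
# Pretend cache: only the first and third image were processed before.
cached = {"img-0": {"pixel_values": "pv0"}, "img-2": {"pixel_values": "pv2"}}
items = ["img-0", "img-1", "img-2"]

maybe_cached = [cached.get(item) for item in items]
missing_idxs = [i for i, out in enumerate(maybe_cached) if out is None]
print(missing_idxs)  # [1]

# Pretend we ran the HF processor on just the missing item.
freshly_processed = {1: {"pixel_values": "pv1"}}

merged = [
    out if out is not None else freshly_processed[i]
    for i, out in enumerate(maybe_cached)
]
print([d["pixel_values"] for d in merged])  # ['pv0', 'pv1', 'pv2']
```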
if all( - len(matches) >= mm_item_counts[modality] - for modality, matches in full_groupby_modality(token_matches) + mm_match_counts.get(modality, 0) >= item_count + for modality, item_count in mm_item_counts.items() ): # yapf: disable token_ids = replace_token_matches( token_ids, @@ -775,7 +943,7 @@ def apply( self, prompt_text: str, mm_data: MultiModalDataDict, - mm_processor_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: """ Process multi-modal inputs to be used in vLLM. @@ -792,20 +960,24 @@ def apply( """ mm_items = self._get_mm_items(mm_data) - hf_inputs = self._apply_hf_processor(prompt_text, mm_items, - mm_processor_kwargs) - prompt_ids, = hf_inputs.pop("input_ids").tolist() - mm_kwargs = MultiModalKwargs(hf_inputs) + prompt_ids, mm_kwargs = self._cached_apply_hf_processor( + prompt_text, + mm_items, + hf_processor_mm_kwargs, + ) - prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs, - mm_processor_kwargs) - all_prompt_repls = self._bind_prompt_replacements(prompt_repls) + unbound_prompt_repls = self._get_prompt_replacements( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls) # If HF processor already inserts placeholder tokens, # there is no need for us to insert them mm_item_counts = mm_items.get_item_counts() - all_placeholders = self._find_placeholders(all_prompt_repls, - prompt_ids, mm_item_counts) + all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, + mm_item_counts) if all_placeholders: tokenizer = self._get_tokenizer() @@ -817,7 +989,7 @@ def apply( all_placeholders, ) = self._apply_prompt_replacements( prompt_ids, - all_prompt_repls, + prompt_repls, mm_item_counts, ) @@ -855,23 +1027,29 @@ def get_dummy_data( from vllm.sequence import SequenceData processor_inputs = self._get_dummy_mm_inputs(mm_counts) - mm_inputs = self.apply(*processor_inputs) + mm_inputs = self.apply( + prompt_text=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] - total_placeholders_by_modality = dict[str, int]() - for modality, placeholders in placeholders_by_modality.items(): - num_placeholders = sum(item["length"] for item in placeholders) - max_tokens = mm_max_tokens[modality] - - if num_placeholders != max_tokens: - logger.warning( - "The processed dummy data has a total of %d placeholder " - "tokens for the '%s' modality, which is not the expected " - "%d tokens.", num_placeholders, modality, max_tokens) - - total_placeholders_by_modality[modality] = num_placeholders + total_placeholders_by_modality = { + modality: sum(item["length"] for item in placeholders) + for modality, placeholders in placeholders_by_modality.items() + } + expected_placeholders_by_modality = { + modality: mm_max_tokens[modality] + for modality in placeholders_by_modality + } + if total_placeholders_by_modality != expected_placeholders_by_modality: + raise AssertionError( + f"The processed dummy data has a total of " + f"{total_placeholders_by_modality} placeholder tokens, which " + f"is not the expected {expected_placeholders_by_modality} " + "tokens.") total_len = len(prompt_token_ids) if total_len > seq_len: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index ded45a7184b5d..3a5e11867ad9e 100644 --- a/vllm/multimodal/registry.py +++ 
b/vllm/multimodal/registry.py @@ -1,10 +1,9 @@ import functools from collections import UserDict -from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, +from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, Sequence, Type, TypeVar) import torch.nn as nn -from typing_extensions import TypeAlias from vllm.inputs import InputProcessingContext from vllm.logger import init_logger @@ -15,7 +14,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import BaseMultiModalProcessor +from .processing import BaseMultiModalProcessor, ProcessingCache from .video import VideoPlugin if TYPE_CHECKING: @@ -23,15 +22,22 @@ logger = init_logger(__name__) +# TODO: Tune the MM cache size +MM_CACHE_SIZE = 256 + N = TypeVar("N", bound=Type[nn.Module]) -MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - BaseMultiModalProcessor] -""" -Constructs a :class:`MultiModalProcessor` instance from the context. -The processing metadata should be derived from the context. -""" +class MultiModalProcessorFactory(Protocol): + """Constructs a :class:`MultiModalProcessor` instance from the context.""" + + def __call__( + self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + ) -> BaseMultiModalProcessor: + ... class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): @@ -71,6 +77,8 @@ def __init__( self._limits_by_model = _MultiModalLimits() + self._processing_cache = ProcessingCache(MM_CACHE_SIZE) + def register_plugin(self, plugin: MultiModalPlugin) -> None: """ Register a multi-modal plugin so it can be recognized by vLLM. @@ -328,15 +336,18 @@ def wrapper(model_cls: N) -> N: return wrapper - def has_processor(self, model_config: "ModelConfig") -> bool: - """ - Test whether a multi-modal processor is defined for a specific model. - """ + def _get_model_cls(self, model_config: "ModelConfig"): # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - return model_cls in self._processor_factories + return model_cls + + def has_processor(self, model_config: "ModelConfig") -> bool: + """ + Test whether a multi-modal processor is defined for a specific model. + """ + return self._get_model_cls(model_config) in self._processor_factories def create_processor( self, @@ -346,12 +357,11 @@ def create_processor( """ Create a multi-modal processor for a specific model and tokenizer. 
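A minimal sketch of the factory protocol above, with stand-in classes for `ProcessingCache` and `BaseMultiModalProcessor`; it only demonstrates the shape of the callable the registry expects, not the real vLLM classes.

```python
from typing import Optional, Protocol


class Cache:                      # stand-in for ProcessingCache
    def __init__(self, capacity: int) -> None:
        self.capacity = capacity


class Processor:                  # stand-in for BaseMultiModalProcessor
    def __init__(self, ctx: object, *, cache: Optional[Cache] = None) -> None:
        self.ctx, self.cache = ctx, cache


class ProcessorFactory(Protocol):
    def __call__(self, ctx: object, *,
                 cache: Optional[Cache] = None) -> Processor:
        ...


def build_processor(ctx: object, *, cache: Optional[Cache] = None) -> Processor:
    return Processor(ctx, cache=cache)


factory: ProcessorFactory = build_processor
processor = factory("ctx", cache=Cache(capacity=256))
print(processor.cache.capacity)  # 256
```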
""" - - # Avoid circular import - from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) + model_cls = self._get_model_cls(model_config) processor_factory = self._processor_factories[model_cls] ctx = InputProcessingContext(model_config, tokenizer) - return processor_factory(ctx) + cache = (None if model_config.disable_mm_preprocessor_cache else + self._processing_cache) + + return processor_factory(ctx, cache=cache) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index f1523667b0466..b12cc83a22970 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,25 +1,31 @@ from functools import lru_cache from typing import Any, cast +from transformers.processing_utils import ProcessorMixin + def get_processor( processor_name: str, *args: Any, trust_remote_code: bool = False, + processor_cls: type[ProcessorMixin] = ProcessorMixin, **kwargs: Any, ): """Load a processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor - from transformers.processing_utils import ProcessorMixin + + processor_factory = (AutoProcessor + if processor_cls == ProcessorMixin else processor_cls) try: - processor = AutoProcessor.from_pretrained( + processor = processor_factory.from_pretrained( processor_name, *args, trust_remote_code=trust_remote_code, - **kwargs) + **kwargs, + ) except ValueError as e: # If the error pertains to the processor class not existing or not # currently being imported, suggest using the --trust-remote-code flag. diff --git a/vllm/utils.py b/vllm/utils.py index 3d198887021dc..5eb4e8c4180c4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -25,11 +25,11 @@ import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import OrderedDict, UserDict, defaultdict -from collections.abc import Iterable, Mapping +from collections.abc import Hashable, Iterable, Mapping from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generator, Generic, Hashable, List, Literal, + Dict, Generator, Generic, List, Literal, NamedTuple, Optional, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 @@ -194,13 +194,29 @@ def reset(self) -> None: self.counter = 0 +class CacheInfo(NamedTuple): + hits: int + total: int + + @property + def hit_ratio(self) -> float: + if self.total == 0: + return 0 + + return self.hits / self.total + + class LRUCache(Generic[_K, _V]): + """Note: This class is not thread safe!""" def __init__(self, capacity: int) -> None: self.cache = OrderedDict[_K, _V]() self.pinned_items = set[_K]() self.capacity = capacity + self._hits = 0 + self._total = 0 + def __contains__(self, key: _K) -> bool: return key in self.cache @@ -218,6 +234,9 @@ def __setitem__(self, key: _K, value: _V) -> None: def __delitem__(self, key: _K) -> None: self.pop(key) + def stat(self) -> CacheInfo: + return CacheInfo(hits=self._hits, total=self._total) + def touch(self, key: _K) -> None: self.cache.move_to_end(key) @@ -226,8 +245,12 @@ def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: if key in self.cache: value = self.cache[key] self.cache.move_to_end(key) + + self._hits += 1 else: value = default + + self._total += 1 return value def put(self, key: _K, value: _V) -> None: