Skip to content

Commit

Permalink
[Model] Refactor Qwen2-VL to use merged multimodal processor (vllm-pr…
Browse files Browse the repository at this point in the history
…oject#11258)

Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
  • Loading branch information
3 people authored and BKitor committed Dec 30, 2024
1 parent 18e3d32 commit 9c86373
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 527 deletions.
8 changes: 6 additions & 2 deletions examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,6 @@ def run_qwen_vl(question: str, modality: str):

# Qwen2-VL
def run_qwen2_vl(question: str, modality: str):
assert modality == "image"

model_name = "Qwen/Qwen2-VL-7B-Instruct"

Expand All @@ -463,8 +462,13 @@ def run_qwen2_vl(question: str, modality: str):
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

if modality == "image":
placeholder = "<|image_pad|>"
elif modality == "video":
placeholder = "<|video_pad|>"

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Any, Dict, Tuple

import pytest
import torch
from PIL.Image import Image
from transformers import AutoTokenizer

from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry
from vllm.inputs import InputContext, InputProcessingContext

from .....conftest import _ImageAssets
from ....utils import build_model_context
Expand All @@ -20,22 +17,9 @@
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
# input mappers.
@pytest.fixture()
def image_input_mapper_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import (
image_input_mapper_for_qwen2_vl)
return image_input_mapper_for_qwen2_vl


@pytest.fixture()
def input_processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import (
input_processor_for_qwen2_vl)
return input_processor_for_qwen2_vl


@pytest.fixture()
def qwen2_vl_context() -> InputContext:
return build_model_context(model_name=MODEL)
def processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
return Qwen2VLMultiModalProcessor


@pytest.fixture()
Expand All @@ -45,123 +29,77 @@ def get_max_qwen2_vl_image_tokens():
return get_max_qwen2_vl_image_tokens


@pytest.fixture()
def dummy_data_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl
return dummy_data_for_qwen2_vl


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 1225),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 324),
])
def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens,
qwen2_vl_context: InputContext,
mm_processor_kwargs: Dict[str, Any],
expected_max_tokens: int):
@pytest.mark.parametrize("model", [MODEL])
def test_qwen2_vl_max_image_tokens(
get_max_qwen2_vl_image_tokens,
model: str,
mm_processor_kwargs: Dict[str, Any],
expected_max_tokens: int,
):
"""Ensure that the max token calc handles min/max pixels properly."""
actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context,
**mm_processor_kwargs)
assert actual_max_tokens == expected_max_tokens


@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [
[{}, 1225, (980, 980)],
[{
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 324, (504, 504)],
])
def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl,
qwen2_vl_context: InputContext,
mm_processor_kwargs: Dict[str, Any],
token_count: int, img_size: Tuple[int, int]):
"""Ensure that the dummy data handles min/max pixels properly."""
seq_len = 3000
hf_config = qwen2_vl_context.get_hf_config()
image_token_id = hf_config.image_token_id

# NOTE: video value is required, but isn't actually used
# when making the dummy data except for error handling currently
dummy_data = dummy_data_for_qwen2_vl(
ctx=qwen2_vl_context,
seq_len=seq_len,
mm_counts={
"image": 1,
"video": 0
},
**mm_processor_kwargs,
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
)
seq_data = dummy_data.seq_data
mm_data = dummy_data.multi_modal_data

# Ensure we have the right number of placeholders for min/max pixel values
assert seq_data.get_token_ids().count(image_token_id) == token_count

# Ensure the images were resized correctly
image = mm_data["image"]
assert isinstance(image, Image)
assert image.size == img_size
actual_max_tokens = get_max_qwen2_vl_image_tokens(
InputContext(ctx.model_config), **mm_processor_kwargs)
assert actual_max_tokens == expected_max_tokens


@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [
({}, 1426),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 330),
])
def test_input_processor(input_processor_for_qwen2_vl,
qwen2_vl_context: InputContext,
image_assets: _ImageAssets, num_placeholders: int,
mm_processor_kwargs: Dict[str, Any]):
"""Ensure that the image processor handles min/max pixels properly."""
tokenizer = AutoTokenizer.from_pretrained(MODEL)
prompt = "<|vision_start|><|image_pad|><|vision_end|>"

image = image_assets[0].pil_image
hf_config = qwen2_vl_context.get_hf_config()
image_token_id = hf_config.image_token_id

inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": [image]})

processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs,
**mm_processor_kwargs)
assert processed_inputs["prompt_token_ids"].count(
image_token_id) == num_placeholders
assert len(processed_inputs["multi_modal_data"]["image"]) == 1


@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [
({}, [5704, 1176]),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, [1320, 1176]),
])
def test_image_mapper_override(qwen2_vl_context: InputContext,
image_assets: _ImageAssets,
mm_processor_kwargs: Dict[str, Any],
pixels_shape: Tuple[int, int]):
"""Ensure that the image mapper handles min/max pixels properly."""
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config)

image = image_assets[0].pil_image

mapped_output = mm_registry.map_input(
qwen2_vl_context.model_config,
{"image": image},
mm_processor_kwargs=mm_processor_kwargs,
@pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
({}, 1426, (5704, 1176)),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 330, (1320, 1176)),
])
@pytest.mark.parametrize("model", [MODEL])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
processor_for_qwen2_vl,
image_assets: _ImageAssets,
model: str,
mm_processor_kwargs: Dict[str, Any],
expected_toks_per_img: int,
expected_pixels_shape: Tuple[int, int],
num_imgs: int,
):
"""Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
)

# Dimension 0 of pixel values should match the product of image_grid_thw
actual_pixels_shape = mapped_output["pixel_values"].shape
assert list(actual_pixels_shape) == pixels_shape
assert actual_pixels_shape[0] == torch.prod(
mapped_output["image_grid_thw"])
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
images = [image_assets[0].pil_image] * num_imgs

mm_data = {"image": images}

processor = processor_for_qwen2_vl(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)

# Ensure we have the right number of placeholders per num_crops size
hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape

assert img_tok_count == expected_toks_per_img * num_imgs
assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
assert pixel_shape[1] == expected_pixels_shape[1]
4 changes: 3 additions & 1 deletion vllm/model_executor/models/qwen2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate

audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)
Expand Down
Loading

0 comments on commit 9c86373

Please sign in to comment.