Skip to content

Commit

Permalink
[Bugfix] Cleanup Pixtral HF code (#11333)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Dec 19, 2024
1 parent 5aef498 commit a0f7d53
Showing 1 changed file with 14 additions and 141 deletions.
155 changes: 14 additions & 141 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from PIL import Image
from transformers import PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens)
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)

from vllm.attention import AttentionMetadata
from vllm.config import ModelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
Expand All @@ -27,19 +27,17 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
resolve_visual_encoder_outputs)
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model, maybe_prefix
from .utils import (init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)

try:
from xformers import ops as xops
Expand Down Expand Up @@ -699,37 +697,14 @@ def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
return grid_length * grid_length


def get_max_pixtral_hf_image_feature_size(
hf_config: PixtralVisionConfig) -> int:
return get_pixtral_hf_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)


def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
return get_max_pixtral_hf_image_feature_size(hf_config)
grid_length = get_pixtral_hf_patch_grid_length(
image_size=hf_config.image_size,
patch_size=hf_config.patch_size,
)


def dummy_seq_data_for_pixtral_hf(
hf_config: PixtralVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image"):
if image_feature_size_override is None:
image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override

return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
# Consider the image_break_token
return (grid_length + 1) * grid_length


def dummy_image_for_pixtral_hf(
Expand Down Expand Up @@ -763,116 +738,14 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
image_width = int(numpy.ceil(image_width / ratio))
image_height = int(numpy.ceil(image_height / ratio))

num_height_tokens, num_width_tokens = _num_image_tokens(
(image_height, image_width), (patch_height, patch_width))
num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
(image_height, image_width),
(patch_height, patch_width),
)

return num_width_tokens, num_height_tokens


def input_processor_for_pixtral_hf(
model_config: ModelConfig,
hf_config: PixtralVisionConfig,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
) -> DecoderOnlyInputs:
assert image_feature_size_override is None, (
"image_feature_size_override is not supported for Pixtral")

multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs

processor = cached_get_processor(model_config.model)

image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_data = [image_data]
elif not is_list_of(image_data, Image.Image):
raise TypeError(f"Invalid image type: {type(image_data)}")

new_prompt = inputs.get("prompt")
new_token_ids = inputs["prompt_token_ids"]

image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token

# Update new_prompt if present
if new_prompt:
parts = new_prompt.split(image_token)
assert len(parts) - 1 == len(image_data)
new_parts = [parts[0]] # Start with the part before any image tokens

for image, next_part in zip(image_data, parts[1:]):
w, h = image.size
(num_width_tokens,
num_height_tokens) = get_pixtral_hf_image_feature_size(
hf_config, image_width=w, image_height=h)

replace_tokens = [image_token] * num_width_tokens + [
image_break_token
]
replace_tokens = replace_tokens * num_height_tokens
replace_tokens[-1] = image_end_token

new_parts.append("".join(replace_tokens))
new_parts.append(next_part)

new_prompt = "".join(new_parts)

# Update new_token_ids
convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
image_token_id = convert_tokens_to_ids(image_token)
image_break_id = convert_tokens_to_ids(image_break_token)
image_end_id = convert_tokens_to_ids(image_end_token)
placeholder_token_id = -999
# Find all image token indices at once
placeholder_indices = [
idx for idx, token_id in enumerate(new_token_ids)
if token_id == image_token_id
]
assert len(placeholder_indices) == len(image_data)
replace_tokens_list = []
for placeholder_idx, image in zip(placeholder_indices, image_data):
new_token_ids[placeholder_idx] = placeholder_token_id

w, h = image.size
(num_width_tokens,
num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
image_width=w,
image_height=h)

replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
replace_tokens = replace_tokens * num_height_tokens
replace_tokens[-1] = image_end_id
replace_tokens_list.append(replace_tokens)

reverse_offsets: List[int] = []
# Backward iteration for replacement without affecting known indices
for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
reversed(replace_tokens_list)):
reverse_offsets.append(
len(new_token_ids) - placeholder_idx + len(replace_tokens))
new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens

placeholder_ranges: List[PlaceholderRange] = []
for reverse_offset, replace_tokens in zip(reversed(reverse_offsets),
replace_tokens_list):
placeholder_ranges.append(
PlaceholderRange(
offset=len(new_token_ids) - reverse_offset,
length=len(replace_tokens),
))

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})


class PixtralHFMLP(nn.Module):

def __init__(
Expand Down

0 comments on commit a0f7d53

Please sign in to comment.