From a5bba7d234b4e0d82e6a64de82a8497760ed44cf Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 6 Nov 2024 19:41:17 +0800 Subject: [PATCH] [Model] Add Idefics3 support (#9767) Signed-off-by: Jee Jee Li Signed-off-by: B-201 Co-authored-by: B-201 --- docs/source/models/supported_models.rst | 6 + examples/offline_inference_vision_language.py | 17 + ...e_inference_vision_language_multi_image.py | 25 + .../vision_language/test_models.py | 16 + vllm/entrypoints/chat_utils.py | 2 + .../models/idefics2_vision_model.py | 25 +- vllm/model_executor/models/idefics3.py | 632 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 8 files changed, 723 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/idefics3.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 87f45cf695c8d..cdcea70c6cb7d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -446,6 +446,12 @@ Text Generation - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. - - ✅︎ + * - :code:`Idefics3ForConditionalGeneration` + - Idefics3 + - T + I + - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc. + - + - * - :code:`InternVLChatModel` - InternVL2 - T + I\ :sup:`E+` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 4fd002caf1763..8d17ce3754515 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -377,6 +377,22 @@ def run_glm4v(question: str, modality: str): return llm, prompt, stop_token_ids +# Idefics3-8B-Llama3 +def run_idefics3(question: str, modality: str): + assert modality == "image" + model_name = "HuggingFaceM4/Idefics3-8B-Llama3" + + llm = LLM(model=model_name, + max_model_len=8192, + max_num_seqs=2, + enforce_eager=True) + prompt = ( + f"<|begin_of_text|>User:{question}\nAssistant:" + ) + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -397,6 +413,7 @@ def run_glm4v(question: str, modality: str): "mllama": run_mllama, "molmo": run_molmo, "glm4v": run_glm4v, + "idefics3": run_idefics3, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index d99684078ff3d..7e883568995a4 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -290,6 +290,30 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: ) +def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: + model_name = "HuggingFaceM4/Idefics3-8B-Llama3" + + # The configuration below has been confirmed to launch on a single L40 GPU. + llm = LLM( + model=model_name, + max_model_len=8192, + max_num_seqs=16, + enforce_eager=True, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|begin_of_text|>User:{placeholders}\n{question}\nAssistant:" # noqa: E501 + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=None, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + model_example_map = { "phi3_v": load_phi3v, "h2ovl_chat": load_h2onvl, @@ -298,6 +322,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: "qwen2_vl": load_qwen2_vl, "qwen_vl_chat": load_qwenvl_chat, "mllama": load_mllama, + "idefics3": load_idefics3, } diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index cfd2d61f2b633..3dbfaafb781af 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -327,6 +327,22 @@ vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output, prompt_path_encoder=model_utils.qwen_prompt_path_encoder, ), + "idefics3": VLMTestInfo( + models=["HuggingFaceM4/Idefics3-8B-Llama3"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}\nAssistant:", # noqa: E501 + img_idx_to_prompt=lambda idx: "", + max_model_len=8192, + max_num_seqs=2, + auto_cls=AutoModelForVision2Seq, + marks=[ + pytest.mark.skipif( + transformers.__version__ < "4.46.0", + reason="Model introduced in HF >= 4.46.0" + ), + large_gpu_mark(min_gb=48), + ], + ), ### Tensor parallel / multi-gpu broadcast tests "broadcast-chameleon": VLMTestInfo( models=["facebook/chameleon-7b"], diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 0ada0aaacda24..ed4e4399d5514 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -187,6 +187,8 @@ def _placeholder_str(self, modality: ModalityStr, return "<|vision_start|><|image_pad|><|vision_end|>" if model_type == "molmo": return "" + if model_type == "idefics3": + return "" raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 53869b8fa6bd8..b21bc2a3f9ce1 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -15,7 +15,7 @@ # limitations under the License. """PyTorch Idefics2 model.""" -from typing import Optional +from typing import Iterable, Optional, Tuple import torch from torch import nn @@ -29,6 +29,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader class Idefics2VisionEmbeddings(nn.Module): @@ -329,3 +330,25 @@ def forward( encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py new file mode 100644 index 0000000000000..e4c98f22fb16f --- /dev/null +++ b/vllm/model_executor/models/idefics3.py @@ -0,0 +1,632 @@ +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Idefics3 model compatible with HuggingFace weights.""" + +import math +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.utils.checkpoint +from PIL import Image +from torch import nn +# Temporary solution for transformers below 4.46.0. +from transformers import PretrainedConfig as Idefics3Config + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal.image import cached_get_image_processor +from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import is_list_of + +# yapf: disable +from .idefics2_vision_model import ( + Idefics2VisionTransformer as Idefics3VisionTransformer) +# yapf: enable +from .interfaces import SupportsMultiModal +from .llama import LlamaModel +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings + +logger = init_logger(__name__) + + +class Idefics3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + """ + rows: List[int] + cols: List[int] + pixel_attention_mask: Optional[torch.BoolTensor] + + +class Idefics3ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + `hidden_size` must match the hidden size of language model backbone. + """ + + +ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] + + +def input_mapper_for_idefics3( + ctx: InputContext, + data: object, +): + model_config = ctx.model_config + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + + if isinstance(data, Image.Image): + images = [[data]] + elif is_list_of(data, Image.Image): + images = [data] + else: + raise TypeError(f"Invalid image type: {type(data)}") + + try: + batch_data = image_processor(images, + return_tensors="pt", + return_row_col_info=True).data + except Exception: + logger.error("Failed to process image (%s)", data) + raise + + return MultiModalInputs(batch_data) + + +def _resize_output_size(height: int, + width: int, + max_len: Optional[int] = None, + min_len: Optional[int] = 1, + max_size: Optional[int] = None) -> Tuple[int, int]: + # Set default value for max_len if not provided + max_len = max(height, width) if max_len is None else max_len + aspect_ratio = width / height + + # Handle the maximum size constraint + if max_size is not None: + max_len = min(max_len, max_size) + + # Adjust dimensions according to the aspect ratio + if width >= height: + width = max_len + height = int(width / aspect_ratio) + else: + height = max_len + width = int(height * aspect_ratio) + + # Ensure both width and height are even (if needed) + height += 1 if height % 2 != 0 else 0 + width += 1 if width % 2 != 0 else 0 + + # Ensure dimensions are not smaller than the minimum length + height = max(height, min_len) + width = max(width, min_len) + + return height, width + + +def _get_resize_output_image_size( + image_size: Tuple[int, int], + resolution_max_side: int, + max_image_size: int = 1820, +) -> Tuple[int, int]: + if resolution_max_side > max_image_size: + raise ValueError( + "`resolution_max_side` cannot be larger than `max_image_size`") + + height, width = image_size + + # Find the output size, when rescaling the longest edge to max_len and + # preserving the aspect ratio + height, width = _resize_output_size(height, + width, + max_len=resolution_max_side) + + return height, width + + +def _prompt_split_image(image_seq_len: int, image_rows: int, image_cols: int, + fake_token_around_image: str, image_token: str, + global_img_token: str) -> str: + """ + Prompt with expanded image tokens for when the image is split + into patches. + """ + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += (fake_token_around_image + + f"" + + image_token * image_seq_len) + text_split_images += "\n" + + text_split_images += "\n" + _prompt_single_image( + image_seq_len=image_seq_len, + fake_token_around_image=fake_token_around_image, + image_token=image_token, + global_img_token=global_img_token) + return text_split_images + + +def _prompt_single_image(image_seq_len: int, fake_token_around_image: str, + image_token: str, global_img_token: str): + """Prompt with expanded image tokens for a single image.""" + return (fake_token_around_image + global_img_token + + image_token * image_seq_len + fake_token_around_image) + + +def _get_image_prompt_string(image_rows: int, image_cols: int, + image_seq_len: int, fake_token_around_image: str, + image_token: str, global_img_token: str): + if image_rows == 0 and image_cols == 0: + return _prompt_single_image( + image_seq_len=image_seq_len, + fake_token_around_image=fake_token_around_image, + image_token=image_token, + global_img_token=global_img_token, + ) + return _prompt_split_image(image_seq_len, image_rows, image_cols, + fake_token_around_image, image_token, + global_img_token) + + +def input_processor_for_idefics3(ctx: InputContext, inputs: DecoderOnlyInputs): + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs + + model_config = ctx.model_config + processor = cached_get_processor(model_config.model) + image_processor = processor.image_processor + tokenizer = processor.tokenizer + size = image_processor.size['longest_edge'] + max_image_size = image_processor.max_image_size['longest_edge'] + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_list = [image_data] + elif is_list_of(image_data, Image.Image): + image_list = image_data + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + image_rows = [] + image_cols = [] + for image in image_list: + height, width = _get_resize_output_image_size(image.size, size) + + rows = math.ceil(height / max_image_size) + cols = math.ceil(width / max_image_size) + image_rows.append(rows) + image_cols.append(cols) + image_rows = [image_rows] + image_cols = [image_cols] + + n_images_in_text = [] + + text = inputs.get("prompt") + if text is not None: + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, " + "or a list of strings") + + fake_image_token = processor.fake_image_token.content + image_token = processor.image_token.content + global_img_token = processor.global_image_tag + + prompt_strings = [] + for sample, sample_rows, sample_cols in zip(text, image_rows, + image_cols): + n_images_in_text.append(sample.count(image_token)) + + # Replace the image token with fake tokens around the expanded + # image token sequence of length `image_seq_len` + image_prompt_strings = [] + for n_rows, n_cols in zip(sample_rows, sample_cols): + image_prompt_string = _get_image_prompt_string( + n_rows, + n_cols, + processor.image_seq_len, + image_token=image_token, + fake_token_around_image=fake_image_token, + global_img_token=global_img_token, + ) + image_prompt_strings.append(image_prompt_string) + + split_sample = sample.split(image_token) + if len(split_sample) == 0: + raise ValueError( + "The image token should be present in the text.") + + # Place in the image prompt strings where the image tokens are + sample = split_sample[0] + for i, image_prompt_string in enumerate(image_prompt_strings): + sample += image_prompt_string + split_sample[i + 1] + prompt_strings.append(sample) + + prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + + return token_inputs( + prompt_token_ids=prompt_token_ids, + prompt=prompt_strings[0], + multi_modal_data=multi_modal_data, + ) + + +def get_max_idefics3_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): + model_config = ctx.model_config + processor = cached_get_processor(model_config.model) + image_seq_len = processor.image_seq_len + image_processor = processor.image_processor + + size = image_processor.size['longest_edge'] + max_image_size = image_processor.max_image_size['longest_edge'] + resized_height, resized_width = size, size + + grid_h = resized_height // max_image_size + grid_w = resized_width // max_image_size + + return (grid_h * grid_w + 1) * image_seq_len + + +def dummy_data_for_idefics3(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]) -> DummyData: + hf_config = ctx.get_hf_config() + num_images = mm_counts["image"] + + processor = cached_get_processor(ctx.model_config.model) + image_seq_len = processor.image_seq_len + max_llm_image_tokens = 17 * image_seq_len * num_images + + seq_data = SequenceData.from_prompt_token_counts( + (hf_config.image_token_id, max_llm_image_tokens), (0, seq_len)) + + width = height = hf_config.vision_config.image_size + image = Image.new("RGB", (width, height), color=0) + mm_data = {"image": [image] if num_images == 1 else [image] * num_images} + + return DummyData(seq_data, mm_data) + + +class Idefics3SimpleMLP(nn.Module): + + def __init__(self, config): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor** + 2) + output_size = config.text_config.hidden_size + self.proj = ReplicatedLinear(input_size, output_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out, _ = self.proj(x) + return out + + +class Idefics3Connector(nn.Module): + + def __init__(self, config): + super().__init__() + self.scale_factor = config.scale_factor + self.modality_projection = Idefics3SimpleMLP(config) + + def pixel_shuffle(self, + x: torch.Tensor, + scale_factor: int = 2) -> torch.Tensor: + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), + embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), + embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: + image_hidden_states = self.pixel_shuffle(image_hidden_states, + self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3Model(nn.Module): + + def __init__( + self, + config: Idefics3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.padding_idx = self.config.text_config.pad_token_id + self.vocab_size = self.config.text_config.vocab_size + + self.vision_model = Idefics3VisionTransformer(config.vision_config, + quant_config) + self.connector = Idefics3Connector(config) + self.text_model = LlamaModel(config.text_config, cache_config, + quant_config) + + self.image_seq_len = int( + ((config.vision_config.image_size // + config.vision_config.patch_size)**2) / (config.scale_factor**2)) + self.image_token_id = self.config.image_token_id + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + rows = kwargs.pop("rows", None) + cols = kwargs.pop("cols", None) + pixel_attention_mask = kwargs.pop("pixel_attention_mask", None) + + if pixel_values is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return Idefics3ImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return Idefics3ImagePixelInputs(type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, + concat=True)), + rows=rows, + cols=cols, + pixel_attention_mask=flatten_bn( + pixel_attention_mask, + concat=True)) + + raise AssertionError("This line should be unreachable.") + + def _image_pixels_to_features( + self, + pixel_values: torch.Tensor, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to( + dtype=self.vision_model.embeddings.patch_embedding.weight.dtype + ) # fp16 compatibility + pixel_values = pixel_values.view(batch_size * num_images, + *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), + pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:]) + pixel_attention_mask = pixel_attention_mask[ + real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, + size=patch_size, + step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, + size=patch_size, + step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + return image_hidden_states + + def _process_image_pixels( + self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + assert self.vision_model is not None + + pixel_values = inputs["data"] + pixel_attention_mask = inputs["pixel_attention_mask"] + + return self._image_pixels_to_features(pixel_values, + pixel_attention_mask) + + def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + image_features = self._process_image_pixels(image_input) + return self.connector(image_features) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.text_model.get_input_embeddings(input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) + else: + inputs_embeds = self.text_model.get_input_embeddings(input_ids) + input_ids = None + + hidden_states = self.text_model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_idefics3) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_idefics3_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_idefics3) +@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) +class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__( + self, + config: Idefics3Config, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + self.model = Idefics3Model(config, cache_config, quant_config) + self.image_token_id = self.config.image_token_id + + self.lm_head = ParallelLMHead( + config.text_config.vocab_size, + config.text_config.hidden_size, + quant_config=quant_config, + ) + if self.config.text_config.tie_word_embeddings: + self.lm_head.weight = self.model.text_model.wte.weight + self.logits_processor = LogitsProcessor(config.text_config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + **kwargs, + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 792c6cec34ae0..32750602b988c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -120,6 +120,7 @@ "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), + "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501