From 56ac3a9f84ae648db1a2ae99a5489307f9844c7c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 00:29:49 +0800 Subject: [PATCH 1/9] port deepseek-vl2 processor Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/deepseek_vl2.py | 27 +- .../transformers_utils/processors/__init__.py | 3 + .../processors/deepseek_vl2.py | 358 ++++++++++++++++++ 3 files changed, 367 insertions(+), 21 deletions(-) create mode 100644 vllm/transformers_utils/processors/__init__.py create mode 100644 vllm/transformers_utils/processors/deepseek_vl2.py diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 99fa941c055d2..e64787fe06433 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -1,7 +1,7 @@ # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" import math -from functools import cached_property, partial +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import AutoProcessor, BatchFeature, ProcessorMixin +from transformers import BatchFeature from vllm.attention import AttentionMetadata from vllm.config import VllmConfig @@ -31,6 +31,8 @@ from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, MlpProjectorConfig, VisionEncoderConfig) +from vllm.transformers_utils.processors.deepseek_vl2 import ( + DeepseekVLV2Processor) from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP @@ -129,25 +131,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(DeepseekVLV2Config) - def get_hf_processor(self) -> ProcessorMixin: - # TODO(Isotr0py): we should get rid of dependency on deepseek_vl2 - # in the future, because it's flasky and lack of maintenance. 
- try: - from deepseek_vl2.models.processing_deepseek_vl_v2 import ( - DeepseekVLV2Processor, select_best_resolution) - AutoProcessor.register("DeepseekVLV2Processor", - DeepseekVLV2Processor) - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "You need to `pip install " - "git+https://github.com/deepseek-ai/DeepSeek-VL2.git` " - "to use this model") from exc - - processor = self.ctx.get_hf_processor(DeepseekVLV2Processor) - processor.select_best_resolution = partial( - select_best_resolution, - candidate_resolutions=processor.candidate_resolutions) - return processor + def get_hf_processor(self) -> DeepseekVLV2Processor: + return self.ctx.get_hf_processor(DeepseekVLV2Processor) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py new file mode 100644 index 0000000000000..a70032ddf27f5 --- /dev/null +++ b/vllm/transformers_utils/processors/__init__.py @@ -0,0 +1,3 @@ +from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor + +__all__ = ["DeepseekVLV2Processor"] \ No newline at end of file diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py new file mode 100644 index 0000000000000..a38688ec0966d --- /dev/null +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -0,0 +1,358 @@ +# adapted from +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +from typing import Dict, Tuple, List, Literal, Optional +import math + +import torch +from torch.nn.utils.rnn import pad_sequence +import torchvision.transforms as T +from transformers import LlamaTokenizerFast, BatchFeature, AutoProcessor +from transformers.processing_utils import ProcessorMixin +from PIL import Image, ImageOps + + +class ImageTransform(object): + def __init__( + self, + mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True + ): + self.mean = mean + self.std = std + self.normalize = normalize + + transform_pipelines = [ + T.ToTensor() + ] + + if normalize: + transform_pipelines.append(T.Normalize(mean, std)) + + self.transform = T.Compose(transform_pipelines) + + def __call__(self, pil_img: Image.Image): + x = self.transform(pil_img) + return x + + + +class DeepseekVLV2Processor(ProcessorMixin): + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + attributes = ["tokenizer"] + + def __init__( + self, + tokenizer: LlamaTokenizerFast, + candidate_resolutions: Tuple[Tuple[int, int]], + patch_size: int, + downsample_ratio: int, + image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + image_token: str = "", + pad_token: str = "<|▁pad▁|>", + add_special_token: bool = False, + ignore_id: int = -100, + **kwargs, + ): + + self.candidate_resolutions = candidate_resolutions + self.image_size = candidate_resolutions[0][0] + self.patch_size = patch_size + self.image_mean = image_mean + self.image_std = image_std + self.normalize = normalize + self.downsample_ratio = downsample_ratio + + self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize) + self.tokenizer = tokenizer + self.tokenizer.padding_side = 'left' # must set this,padding side with make a difference in batch inference + + # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id' + if tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({'pad_token': pad_token}) + + # add image token + image_token_id = self.tokenizer.vocab.get(image_token) + if image_token_id is None: + special_tokens = [image_token] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + self.image_token_id = self.tokenizer.vocab.get(image_token) + + # add five special tokens for grounding-related tasks + # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|> + special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>'] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + + # add special tokens for SFT data + special_tokens = ["<|User|>", "<|Assistant|>"] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + + self.image_token = image_token + self.pad_token = pad_token + self.add_special_token = add_special_token + self.ignore_id = ignore_id + + super().__init__( + tokenizer, + **kwargs, + ) + + def select_best_resolution(self, image_size): + # used for cropping + original_width, original_height = image_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for width, height in self.candidate_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = 
int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def pad_id(self): + return self.tokenizer.pad_token_id + + def encode(self, text: str, bos: bool = True, eos: bool = False): + t = self.tokenizer.encode(text, add_special_tokens=False) + + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + + return t + + def decode(self, t: List[int], **kwargs) -> str: + return self.tokenizer.decode(t, **kwargs) + + def process_one( + self, + prompt: str, + images: List[Image.Image], + inference_mode: bool = True, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + inference_mode (bool): if True, then remove the last eos token; + system_prompt (str): the system prompt; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert ( + prompt is not None and images is not None + ), "prompt and images must be used at the same time." 
+ + sft_format = prompt + tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens = self.tokenize_with_images( + sft_format, images, bos=True, eos=True, cropping=len(images) <= 2) + masked_tokenized_str = [] + for token_index in tokenized_str: + if token_index != self.image_token_id: + masked_tokenized_str.append(token_index) + else: + masked_tokenized_str.append(self.ignore_id) + + assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \ + (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " + f"imags_seq_mask's length {len(images_seq_mask)}, are not equal") + + input_ids = torch.LongTensor(tokenized_str) + target_ids = torch.LongTensor(masked_tokenized_str) + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) + + # set input_ids < 0 | input_ids == self.image_token_id as ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = self.ignore_id + input_ids[input_ids < 0] = self.pad_id + + if inference_mode: + # 去掉结尾的eos token + assert input_ids[-1] == self.eos_id + input_ids = input_ids[:-1] + target_ids = target_ids[:-1] + images_seq_mask = images_seq_mask[:-1] + + if len(images_list) == 0: + images = torch.zeros((1, 3, self.image_size, self.image_size)) + images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) + else: + images = torch.stack(images_list, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + + prepare = BatchFeature( + sft_format=sft_format, + input_ids=input_ids, + target_ids=target_ids, + images=images, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + num_image_tokens=num_image_tokens, + tensor_type="pt", + ) + + return prepare + + def __call__( + self, + *, + prompt: str, + images: List[Image.Image], + inference_mode: bool = True, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + force_batchify (bool): force batchify the inputs; + inference_mode (bool): if True, then remove the last eos token; + system_prompt (str): the system prompt; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + prepare = self.process_one( + prompt=prompt, + images=images, + inference_mode=inference_mode, + ) + + return prepare + + def tokenize_with_images( + self, + conversation: str, + images: List[Image.Image], + bos: bool = True, + eos: bool = True, + cropping: bool = True, + ): + """Tokenize text with tags.""" + assert conversation.count(self.image_token) == len(images) + text_splits = conversation.split(self.image_token) + images_list, images_seq_mask, images_spatial_crop = [], [], [] + num_image_tokens = [] + tokenized_str = [] + for text_sep, image in zip(text_splits, images): + """encode text_sep""" + tokenized_sep = self.encode(text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """select best resolution for anyres""" + if cropping: + best_width, best_height = self.select_best_resolution(image.size) + else: + best_width, best_height = self.image_size, self.image_size + + """process the global view""" + global_view = ImageOps.pad(image, 
(self.image_size, self.image_size), + color=tuple(int(x * 255) for x in self.image_transform.mean)) + images_list.append(self.image_transform(global_view)) + + """process the local views""" + local_view = ImageOps.pad(image, (best_width, best_height), + color=tuple(int(x * 255) for x in self.image_transform.mean)) + for i in range(0, best_height, self.image_size): + for j in range(0, best_width, self.image_size): + images_list.append( + self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size)))) + + """record height / width crop num""" + num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size + images_spatial_crop.append([num_width_tiles, num_height_tiles]) + + """add image tokens""" + h = w = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio) + # global views tokens h * (w + 1), 1 is for line seperator + tokenized_image = [self.image_token_id] * h * (w + 1) + # add a seperator between global and local views + tokenized_image += [self.image_token_id] + # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1) + tokenized_image += [self.image_token_id] * (num_height_tiles * h) * (num_width_tiles * w + 1) + + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + num_image_tokens.append(len(tokenized_image)) + + """process the last text split""" + tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """add the bos and eos tokens""" + if bos: + tokenized_str = [self.bos_id] + tokenized_str + images_seq_mask = [False] + images_seq_mask + if eos: + tokenized_str = tokenized_str + [self.eos_id] + images_seq_mask = images_seq_mask + [False] + + assert len(tokenized_str) == len( + images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + + return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens + + +AutoProcessor.register("DeepseekVLV2Processor", + DeepseekVLV2Processor) \ No newline at end of file From 307f802580a28faac8bf44d6d4714dc36c574d90 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 19:17:11 +0800 Subject: [PATCH 2/9] fix ported processor Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/deepseek_vl2.py | 35 +---- .../transformers_utils/processors/__init__.py | 5 +- .../processors/deepseek_vl2.py | 141 +++++++++--------- 3 files changed, 81 insertions(+), 100 deletions(-) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index a6a77c235409b..c7ab5e01cd226 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -209,31 +209,15 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], ) -> BatchFeature: if mm_data: - outputs = self.info.ctx.call_hf_processor( + processed_outputs = self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(prompt=prompt, **mm_data), mm_kwargs, ) - - # Deepseek-vl2 processor don't return BatchFeature, - # we need to manually create it - processed_outputs = dict(input_ids=outputs["input_ids"]) - processed_outputs = BatchFeature(data=dict(processed_outputs), - tensor_type="pt") - - # Remove batch dimension from processor outputs, - # because we will try batch to create NestedTensors target_dtype = 
self.info.ctx.model_config.dtype - pixel_values = outputs["images"].to(target_dtype).squeeze(0) - images_spatial_crop = outputs["images_spatial_crop"].squeeze(0) - patches_per_image = [ - x.prod().item() + 1 for x in images_spatial_crop - ] - - # Rename `images` -> `pixel_values` to avoid confusion - processed_outputs["pixel_values"] = list( - pixel_values.split(patches_per_image)) - processed_outputs["images_spatial_crop"] = images_spatial_crop + processed_outputs["pixel_values"] = ( + processed_outputs["pixel_values"].unsqueeze(0).to(target_dtype) + ) else: tokenizer = self.info.get_tokenizer() processed_outputs = tokenizer(prompt, @@ -341,18 +325,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" ) - if self.text_config.topk_method == "noaux_tc": - architectures = ["DeepseekV3ForCausalLM"] - elif not self.text_config.use_mla: - architectures = ["DeepseekForCausalLM"] - else: - architectures = ["DeepseekV2ForCausalLM"] - self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=self.text_config, prefix=maybe_prefix(prefix, "language"), - architectures=architectures, + architectures=["DeepseekV3ForCausalLM"] + if self.text_config.topk_method == "noaux_tc" else + ["DeepseekV2ForCausalLM"], ) self.make_empty_intermediate_tensors = ( diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index a70032ddf27f5..9c71b8cada32e 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -1,3 +1,4 @@ -from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor +from vllm.transformers_utils.processors.deepseek_vl2 import ( + DeepseekVLV2Processor) -__all__ = ["DeepseekVLV2Processor"] \ No newline at end of file +__all__ = ["DeepseekVLV2Processor"] diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index a38688ec0966d..8a5b850aa4506 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -1,4 +1,7 @@ -# adapted from +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/ff23960c5cf9e6874b44be38af930cfb0ccbb620/deepseek_vl2/models/processing_deepseek_vl_v2.py # Copyright (c) 2023-2024 DeepSeek. # # Permission is hereby granted, free of charge, to any person obtaining a copy of @@ -18,31 +21,27 @@ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-from typing import Dict, Tuple, List, Literal, Optional import math +from typing import List, Tuple import torch -from torch.nn.utils.rnn import pad_sequence import torchvision.transforms as T -from transformers import LlamaTokenizerFast, BatchFeature, AutoProcessor -from transformers.processing_utils import ProcessorMixin from PIL import Image, ImageOps +from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast +from transformers.processing_utils import ProcessorMixin -class ImageTransform(object): - def __init__( - self, - mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), - std: Tuple[float, float, float] = (0.5, 0.5, 0.5), - normalize: bool = True - ): +class ImageTransform: + + def __init__(self, + mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True): self.mean = mean self.std = std self.normalize = normalize - transform_pipelines = [ - T.ToTensor() - ] + transform_pipelines = [T.ToTensor()] if normalize: transform_pipelines.append(T.Normalize(mean, std)) @@ -54,25 +53,26 @@ def __call__(self, pil_img: Image.Image): return x - class DeepseekVLV2Processor(ProcessorMixin): tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") attributes = ["tokenizer"] def __init__( - self, - tokenizer: LlamaTokenizerFast, - candidate_resolutions: Tuple[Tuple[int, int]], - patch_size: int, - downsample_ratio: int, - image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), - image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5), - normalize: bool = True, - image_token: str = "", - pad_token: str = "<|▁pad▁|>", - add_special_token: bool = False, - ignore_id: int = -100, - **kwargs, + self, + tokenizer: LlamaTokenizerFast, + candidate_resolutions: Tuple[Tuple[int, int]], + patch_size: int, + downsample_ratio: int, + image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + image_token: str = "", + pad_token: str = "<|▁pad▁|>", + add_special_token: bool = False, + sft_format: str = "deepseek", + mask_prompt: bool = True, + ignore_id: int = -100, + **kwargs, ): self.candidate_resolutions = candidate_resolutions @@ -113,13 +113,15 @@ def __init__( self.image_token = image_token self.pad_token = pad_token self.add_special_token = add_special_token + self.sft_format = sft_format + self.mask_prompt = mask_prompt self.ignore_id = ignore_id super().__init__( tokenizer, **kwargs, ) - + def select_best_resolution(self, image_size): # used for cropping original_width, original_height = image_size @@ -129,11 +131,15 @@ def select_best_resolution(self, image_size): for width, height in self.candidate_resolutions: scale = min(width / original_width, height / original_height) - downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) - effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + downscaled_width, downscaled_height = int( + original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, + original_width * original_height) wasted_resolution = (width * height) - effective_resolution - if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < 
min_wasted_resolution): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (width, height) @@ -166,11 +172,11 @@ def decode(self, t: List[int], **kwargs) -> str: return self.tokenizer.decode(t, **kwargs) def process_one( - self, - prompt: str, - images: List[Image.Image], - inference_mode: bool = True, - **kwargs, + self, + prompt: str, + images: List[Image.Image], + inference_mode: bool = True, + **kwargs, ): """ @@ -191,9 +197,8 @@ def process_one( - num_image_tokens (List[int]): the number of image tokens """ - assert ( - prompt is not None and images is not None - ), "prompt and images must be used at the same time." + assert (prompt is not None and images is not None + ), "prompt and images must be used at the same time." sft_format = prompt tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens = self.tokenize_with_images( @@ -214,7 +219,8 @@ def process_one( images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) # set input_ids < 0 | input_ids == self.image_token_id as ignore_id - target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = self.ignore_id + target_ids[(input_ids < 0) | + (input_ids == self.image_token_id)] = self.ignore_id input_ids[input_ids < 0] = self.pad_id if inference_mode: @@ -225,42 +231,38 @@ def process_one( images_seq_mask = images_seq_mask[:-1] if len(images_list) == 0: - images = torch.zeros((1, 3, self.image_size, self.image_size)) + pixel_values = torch.zeros((1, 3, self.image_size, self.image_size)) images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) else: - images = torch.stack(images_list, dim=0) + pixel_values = torch.stack(images_list, dim=0) images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) - prepare = BatchFeature( - sft_format=sft_format, + input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + pixel_values = pixel_values.unsqueeze(0) + + prepare = BatchFeature(data=dict( input_ids=input_ids, - target_ids=target_ids, - images=images, + pixel_values=images, images_seq_mask=images_seq_mask, images_spatial_crop=images_spatial_crop, num_image_tokens=num_image_tokens, - tensor_type="pt", - ) - + )) return prepare def __call__( - self, - *, - prompt: str, - images: List[Image.Image], - inference_mode: bool = True, - **kwargs, + self, + *, + prompt: str, + images: List[Image.Image], + inference_mode: bool = True, + **kwargs, ): """ Args: prompt (str): the formatted prompt; - conversations (List[Dict]): conversations with a list of messages; images (List[ImageType]): the list of images; - force_batchify (bool): force batchify the inputs; inference_mode (bool): if True, then remove the last eos token; - system_prompt (str): the system prompt; **kwargs: Returns: @@ -280,12 +282,12 @@ def __call__( return prepare def tokenize_with_images( - self, - conversation: str, - images: List[Image.Image], - bos: bool = True, - eos: bool = True, - cropping: bool = True, + self, + conversation: str, + images: List[Image.Image], + bos: bool = True, + eos: bool = True, + cropping: bool = True, ): """Tokenize text with tags.""" assert conversation.count(self.image_token) == len(images) @@ -324,9 +326,9 @@ def tokenize_with_images( """add image tokens""" h = w = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio) - # global views tokens h * (w + 1), 1 is for line seperator + # global views tokens h * (w + 1), 1 is for line separator tokenized_image = [self.image_token_id] * h * (w + 1) - # add a seperator 
between global and local views + # add a separator between global and local views tokenized_image += [self.image_token_id] # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1) tokenized_image += [self.image_token_id] * (num_height_tiles * h) * (num_width_tiles * w + 1) @@ -354,5 +356,4 @@ def tokenize_with_images( return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens -AutoProcessor.register("DeepseekVLV2Processor", - DeepseekVLV2Processor) \ No newline at end of file +AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor) From 049647fc80ce363a0ba67f8191348da2e7385c49 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 19:49:56 +0800 Subject: [PATCH 3/9] revert Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/deepseek_vl2.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index c7ab5e01cd226..8037181ae72a2 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -324,14 +324,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): raise ValueError( f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" ) + + if self.text_config.topk_method == "noaux_tc": + architectures = ["DeepseekV3ForCausalLM"] + elif not self.text_config.use_mla: + architectures = ["DeepseekForCausalLM"] + else: + architectures = ["DeepseekV2ForCausalLM"] self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=self.text_config, prefix=maybe_prefix(prefix, "language"), - architectures=["DeepseekV3ForCausalLM"] - if self.text_config.topk_method == "noaux_tc" else - ["DeepseekV2ForCausalLM"], + architectures=architectures, ) self.make_empty_intermediate_tensors = ( From 68d19a9ae259981c0cf779c812eeaeeee3422e7e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 19:53:51 +0800 Subject: [PATCH 4/9] remove dependency Signed-off-by: Isotr0py <2037008807@qq.com> --- .buildkite/test-pipeline.yaml | 1 - docs/source/models/supported_models.md | 10 ++-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7442de245bd80..64cfda838adaf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -52,7 +52,6 @@ steps: - tests/worker - tests/standalone_tests/lazy_torch_compile.py commands: - - pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test - python3 standalone_tests/lazy_torch_compile.py - pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s async_engine # AsyncLLMEngine diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index d07cde3db5c6e..2edb610ddf959 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -767,16 +767,10 @@ See [this page](#generative-models) for more information on how to use generativ E Pre-computed embeddings can be inputted for this modality. + Multiple items can be inputted per text prompt for this modality. 
-````{note} -To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package: - -```shell -pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git +```{note} +To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM. ``` -Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM. -```` - ```{note} To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. ``` From 3f86091d1ab6a1f7085c80080f96d99f18b84a74 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 20:49:43 +0800 Subject: [PATCH 5/9] fix pixel_values shape Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/deepseek_vl2.py | 2 +- .../processors/deepseek_vl2.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 8037181ae72a2..2f8999ebea6cd 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -216,7 +216,7 @@ def _call_hf_processor( ) target_dtype = self.info.ctx.model_config.dtype processed_outputs["pixel_values"] = ( - processed_outputs["pixel_values"].unsqueeze(0).to(target_dtype) + processed_outputs["pixel_values"].to(target_dtype) ) else: tokenizer = self.info.get_tokenizer() diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index 8a5b850aa4506..99ace3e1b3b7c 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -240,13 +240,16 @@ def process_one( input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) pixel_values = pixel_values.unsqueeze(0) - prepare = BatchFeature(data=dict( - input_ids=input_ids, - pixel_values=images, - images_seq_mask=images_seq_mask, - images_spatial_crop=images_spatial_crop, - num_image_tokens=num_image_tokens, - )) + prepare = BatchFeature( + data=dict( + input_ids=input_ids, + pixel_values=pixel_values, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + num_image_tokens=num_image_tokens, + ), + tensor_type="pt", + ) return prepare def __call__( From 3ade1ece7249e933e63140475e63e1479f19135a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 20:53:00 +0800 Subject: [PATCH 6/9] fix example model type Signed-off-by: Isotr0py <2037008807@qq.com> --- examples/offline_inference/vision_language_multi_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index cf3c5dd4e0a2c..43c44fa867e0a 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: model_example_map = { "aria": load_aria, - "deepseek_vl2": load_deepseek_vl2, + "deepseek_vl_v2": load_deepseek_vl2, "h2ovl_chat": load_h2onvl, "idefics3": load_idefics3, "internvl_chat": load_internvl, From a02ca6a8cb8cee2a67b58e283cae4efdc746d836 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 
21:08:48 +0800 Subject: [PATCH 7/9] fix multi-images processing Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/deepseek_vl2.py | 14 ++++++++++---- vllm/transformers_utils/processors/deepseek_vl2.py | 7 +++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 2f8999ebea6cd..4d3d1c329a2c0 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -215,9 +215,15 @@ def _call_hf_processor( mm_kwargs, ) target_dtype = self.info.ctx.model_config.dtype - processed_outputs["pixel_values"] = ( - processed_outputs["pixel_values"].to(target_dtype) - ) + pixel_values = processed_outputs.pop("pixel_values").to( + target_dtype) + # split pixel values into patches corresponding to each image + images_spatial_crop = processed_outputs["images_spatial_crop"] + patches_per_image = [ + x.prod().item() + 1 for x in images_spatial_crop + ] + pixel_values = pixel_values.split(patches_per_image) + processed_outputs["pixel_values"] = pixel_values else: tokenizer = self.info.get_tokenizer() processed_outputs = tokenizer(prompt, @@ -324,7 +330,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): raise ValueError( f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" ) - + if self.text_config.topk_method == "noaux_tc": architectures = ["DeepseekV3ForCausalLM"] elif not self.text_config.use_mla: diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index 99ace3e1b3b7c..27cdf6bc22d0e 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -192,7 +192,7 @@ def process_one( outputs (BaseProcessorOutput): the output of the processor, - input_ids (torch.LongTensor): [N + image tokens] - target_ids (torch.LongTensor): [N + image tokens] - - images (torch.FloatTensor): [n_images, 3, H, W] + - pixel_values (torch.FloatTensor): [n_patches, 3, H, W] - image_id (int): the id of the image token - num_image_tokens (List[int]): the number of image tokens """ @@ -237,8 +237,7 @@ def process_one( pixel_values = torch.stack(images_list, dim=0) images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) - input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) - pixel_values = pixel_values.unsqueeze(0) + input_ids = input_ids.unsqueeze(0) prepare = BatchFeature( data=dict( @@ -247,7 +246,7 @@ def process_one( images_seq_mask=images_seq_mask, images_spatial_crop=images_spatial_crop, num_image_tokens=num_image_tokens, - ), + ), tensor_type="pt", ) return prepare From cbd4ab40b8a500d5415331533357f1c5e616000d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 21:22:57 +0800 Subject: [PATCH 8/9] add processor test Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/multimodal/processing/test_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0a38779e0e4f0..1e3e7ea50b122 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -22,6 +22,8 @@ def _test_processing_correctness( ): if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3": hf_overrides = {"architectures": ["MantisForConditionalGeneration"]} + elif model_id == "deepseek-ai/deepseek-vl2-tiny": + 
hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]} else: hf_overrides = {} @@ -139,6 +141,7 @@ def _test_processing_correctness( ("rhymes-ai/Aria", {"image": True}), ("Salesforce/blip2-opt-2.7b", {"image": False}), ("facebook/chameleon-7b", {"image": False}), + ("deepseek-ai/deepseek-vl2-tiny", {"image": True}), ("adept/fuyu-8b", {"image": False}), ("llava-hf/llava-1.5-7b-hf", {"image": True}), ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), From 000bbfd6a5408bd9d47bf18f83d553eec91fcefa Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 17 Jan 2025 22:44:15 +0800 Subject: [PATCH 9/9] use model repo with dynamic module Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/decoder_only/vision_language/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 5710303548c34..ca572cc39e538 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -190,7 +190,7 @@ dtype="bfloat16", ), "deepseek_vl_v2": VLMTestInfo( - models=["deepseek-ai/deepseek-vl2-tiny"], + models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501 max_model_len=4096,
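Taken together, the series removes the external `deepseek_vl2` dependency, ships the ported `DeepseekVLV2Processor` under `vllm/transformers_utils/processors/`, and registers it with `AutoProcessor`, so the model runs once the `architectures` override from the docs change is supplied. Below is a minimal offline-inference sketch of that flow, assuming a vLLM build with this series applied; the `<image>` placeholder and the `<|User|>`/`<|Assistant|>` format are inferred from the test prompt formatter above and the checkpoint name from the processing test, so treat the exact strings as assumptions rather than a definitive recipe.

```python
# Hedged sketch: exercise the ported DeepseekVLV2Processor end to end.
# Assumes this patch series is applied and the checkpoint below is reachable.
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset  # bundled demo image

llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",
    # Architecture override required per the supported_models.md note above.
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
    max_model_len=4096,
)

# Prompt format mirrors the prompt_formatter used in the VLM tests;
# "<image>" is the assumed image placeholder token.
prompt = "<|User|>: <image>\nWhat is shown in this image?\n\n<|Assistant|>:"

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

The equivalent server-side form, per the docs note in patch 4, is to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when launching `vllm serve`.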