[Model] Port deepseek-vl2 processor and remove deepseek_vl2 dependency #12169

Merged: 11 commits, Jan 18, 2025
1 change: 0 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -52,7 +52,6 @@ steps:
- tests/worker
- tests/standalone_tests/lazy_torch_compile.py
commands:
- - pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
- python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
10 changes: 2 additions & 8 deletions docs/source/models/supported_models.md
@@ -767,16 +767,10 @@ See [this page](#generative-models) for more information on how to use generative models.
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

- ````{note}
- To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package:
-
- ```shell
- pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
+ ```{note}
+ To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
```
-
- Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
- ````

```{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```
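As a quick illustration of the DeepSeek-VL2 note above, here is a minimal sketch of passing the same override through the Python API instead of the CLI flag. The tiny checkpoint name is the one referenced elsewhere in this PR, the prompt is a placeholder, and image inputs are omitted for brevity.

```python
# Minimal sketch: applying the architecture override from the note above via the
# Python API. Prompt and sampling settings are placeholders; no image is attached.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)
outputs = llm.generate("Briefly describe what DeepSeek-VL2 is.",
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```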
2 changes: 1 addition & 1 deletion examples/offline_inference/vision_language_multi_image.py
@@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
"aria": load_aria,
- "deepseek_vl2": load_deepseek_vl2,
+ "deepseek_vl_v2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/decoder_only/vision_language/test_models.py
@@ -190,7 +190,7 @@
dtype="bfloat16",
),
"deepseek_vl_v2": VLMTestInfo(
- models=["deepseek-ai/deepseek-vl2-tiny"],
+ models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
max_model_len=4096,
3 changes: 3 additions & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -22,6 +22,8 @@ def _test_processing_correctness(
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
+ elif model_id == "deepseek-ai/deepseek-vl2-tiny":
+ hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
else:
hf_overrides = {}

@@ -139,6 +141,7 @@ def _test_processing_correctness(
("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": False}),
+ ("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
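Not part of this PR, but as more checkpoints accumulate per-model `hf_overrides` in this test helper, the `elif` chain added above could eventually become a lookup table. A rough sketch, with an illustrative helper name:

```python
# Hypothetical refactor of the override selection above (not in this PR):
# keep per-model overrides in a dict and fall back to no overrides.
HF_OVERRIDES_BY_MODEL = {
    "TIGER-Lab/Mantis-8B-siglip-llama3": {
        "architectures": ["MantisForConditionalGeneration"]
    },
    "deepseek-ai/deepseek-vl2-tiny": {
        "architectures": ["DeepseekVLV2ForCausalLM"]
    },
}


def get_hf_overrides(model_id: str) -> dict:
    return HF_OVERRIDES_BY_MODEL.get(model_id, {})
```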
51 changes: 13 additions & 38 deletions vllm/model_executor/models/deepseek_vl2.py
@@ -1,15 +1,15 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
- from functools import cached_property, partial
+ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
- from transformers import AutoProcessor, BatchFeature, ProcessorMixin
+ from transformers import BatchFeature

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
@@ -31,6 +31,8 @@
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig,
VisionEncoderConfig)
+ from vllm.transformers_utils.processors.deepseek_vl2 import (
+ DeepseekVLV2Processor)
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
@@ -129,25 +131,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config)

- def get_hf_processor(self) -> ProcessorMixin:
- # TODO(Isotr0py): we should get rid of dependency on deepseek_vl2
- # in the future, because it's flasky and lack of maintenance.
- try:
- from deepseek_vl2.models.processing_deepseek_vl_v2 import (
- DeepseekVLV2Processor, select_best_resolution)
- AutoProcessor.register("DeepseekVLV2Processor",
- DeepseekVLV2Processor)
- except ModuleNotFoundError as exc:
- raise ModuleNotFoundError(
- "You need to `pip install "
- "git+https://github.com/deepseek-ai/DeepSeek-VL2.git` "
- "to use this model") from exc
-
- processor = self.ctx.get_hf_processor(DeepseekVLV2Processor)
- processor.select_best_resolution = partial(
- select_best_resolution,
- candidate_resolutions=processor.candidate_resolutions)
- return processor
+ def get_hf_processor(self) -> DeepseekVLV2Processor:
+ return self.ctx.get_hf_processor(DeepseekVLV2Processor)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@@ -224,31 +209,21 @@ def _call_hf_processor(
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
- outputs = self.info.ctx.call_hf_processor(
+ processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
-
- # Deepseek-vl2 processor don't return BatchFeature,
- # we need to manually create it
- processed_outputs = dict(input_ids=outputs["input_ids"])
- processed_outputs = BatchFeature(data=dict(processed_outputs),
- tensor_type="pt")
-
- # Remove batch dimension from processor outputs,
- # because we will try batch to create NestedTensors
target_dtype = self.info.ctx.model_config.dtype
- pixel_values = outputs["images"].to(target_dtype).squeeze(0)
- images_spatial_crop = outputs["images_spatial_crop"].squeeze(0)
+ pixel_values = processed_outputs.pop("pixel_values").to(
+ target_dtype)
+ # split pixel values into patches corresponding to each image
+ images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [
x.prod().item() + 1 for x in images_spatial_crop
]
-
- # Rename `images` -> `pixel_values` to avoid confusion
- processed_outputs["pixel_values"] = list(
- pixel_values.split(patches_per_image))
- processed_outputs["images_spatial_crop"] = images_spatial_crop
+ pixel_values = pixel_values.split(patches_per_image)
+ processed_outputs["pixel_values"] = pixel_values
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,
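To make the per-image patch split above concrete, here is a small standalone sketch. The crop counts and tensor shapes are invented for illustration; only the `patches_per_image` arithmetic and the `split` call mirror the new code.

```python
# Standalone illustration of the per-image patch split in _call_hf_processor.
# Shapes are made up; only the splitting logic mirrors the diff above.
import torch

# Two images, tiled into 2x3 and 1x2 local crops respectively.
images_spatial_crop = torch.tensor([[2, 3], [1, 2]])

# Each image contributes one global view plus its local crops.
patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]  # [7, 3]

# The processor stacks all patches along dim 0; split them back per image.
pixel_values = torch.randn(sum(patches_per_image), 3, 384, 384)
per_image = pixel_values.split(patches_per_image)
print([tuple(p.shape) for p in per_image])  # [(7, 3, 384, 384), (3, 3, 384, 384)]
```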
4 changes: 4 additions & 0 deletions vllm/transformers_utils/processors/__init__.py
@@ -0,0 +1,4 @@
+ from vllm.transformers_utils.processors.deepseek_vl2 import (
+ DeepseekVLV2Processor)
+
+ __all__ = ["DeepseekVLV2Processor"]
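With the processor now living in vLLM itself, downstream code no longer needs the external `deepseek_vl2` package or an `AutoProcessor.register` call. A rough usage sketch follows; the keyword call style is inferred from how `_call_hf_processor` consumes the outputs, so treat the exact signature as an assumption rather than documented API.

```python
# Rough sketch only: assumes the ported processor follows the usual
# ProcessorMixin.from_pretrained interface and a prompt/images keyword call,
# as implied by the _call_hf_processor changes above.
from PIL import Image

from vllm.transformers_utils.processors import DeepseekVLV2Processor

processor = DeepseekVLV2Processor.from_pretrained("deepseek-ai/deepseek-vl2-tiny")
image = Image.new("RGB", (640, 480))
batch = processor(prompt="<image>\nDescribe this image.", images=[image])
# Expected (not guaranteed) keys: input_ids, pixel_values, images_spatial_crop.
print(list(batch.keys()))
```

Because the class ships with vLLM, the extra `pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git` step removed from the CI pipeline above is no longer needed.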