Skip to content

Commit

Permalink
[Model] Port deepseek-vl2 processor, remove dependency (vllm-project#…
Browse files Browse the repository at this point in the history
…12169)

Signed-off-by: Isotr0py <[email protected]>
  • Loading branch information
Isotr0py authored and lckr committed Jan 19, 2025
1 parent 59b265c commit 5428987
Show file tree
Hide file tree
Showing 8 changed files with 385 additions and 49 deletions.
1 change: 0 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ steps:
- tests/worker
- tests/standalone_tests/lazy_torch_compile.py
commands:
- pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
- python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
Expand Down
10 changes: 2 additions & 8 deletions docs/source/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -767,16 +767,10 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

````{note}
To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package:
```shell
pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
```{note}
To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
```

Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
````

```{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
```
Expand Down
2 changes: 1 addition & 1 deletion examples/offline_inference/vision_language_multi_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:

model_example_map = {
"aria": load_aria,
"deepseek_vl2": load_deepseek_vl2,
"deepseek_vl_v2": load_deepseek_vl2,
"h2ovl_chat": load_h2onvl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/decoder_only/vision_language/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@
dtype="bfloat16",
),
"deepseek_vl_v2": VLMTestInfo(
models=["deepseek-ai/deepseek-vl2-tiny"],
models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501
max_model_len=4096,
Expand Down
3 changes: 3 additions & 0 deletions tests/models/multimodal/processing/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def _test_processing_correctness(
):
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
elif model_id == "deepseek-ai/deepseek-vl2-tiny":
hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
else:
hf_overrides = {}

Expand Down Expand Up @@ -139,6 +141,7 @@ def _test_processing_correctness(
("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": False}),
("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
Expand Down
51 changes: 13 additions & 38 deletions vllm/model_executor/models/deepseek_vl2.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
from functools import cached_property, partial
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import AutoProcessor, BatchFeature, ProcessorMixin
from transformers import BatchFeature

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
Expand All @@ -31,6 +31,8 @@
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig,
VisionEncoderConfig)
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
Expand Down Expand Up @@ -129,25 +131,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config)

def get_hf_processor(self) -> ProcessorMixin:
# TODO(Isotr0py): we should get rid of dependency on deepseek_vl2
# in the future, because it's flasky and lack of maintenance.
try:
from deepseek_vl2.models.processing_deepseek_vl_v2 import (
DeepseekVLV2Processor, select_best_resolution)
AutoProcessor.register("DeepseekVLV2Processor",
DeepseekVLV2Processor)
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/deepseek-ai/DeepSeek-VL2.git` "
"to use this model") from exc

processor = self.ctx.get_hf_processor(DeepseekVLV2Processor)
processor.select_best_resolution = partial(
select_best_resolution,
candidate_resolutions=processor.candidate_resolutions)
return processor
def get_hf_processor(self) -> DeepseekVLV2Processor:
return self.ctx.get_hf_processor(DeepseekVLV2Processor)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
Expand Down Expand Up @@ -224,31 +209,21 @@ def _call_hf_processor(
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
outputs = self.info.ctx.call_hf_processor(
processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
)

# Deepseek-vl2 processor don't return BatchFeature,
# we need to manually create it
processed_outputs = dict(input_ids=outputs["input_ids"])
processed_outputs = BatchFeature(data=dict(processed_outputs),
tensor_type="pt")

# Remove batch dimension from processor outputs,
# because we will try batch to create NestedTensors
target_dtype = self.info.ctx.model_config.dtype
pixel_values = outputs["images"].to(target_dtype).squeeze(0)
images_spatial_crop = outputs["images_spatial_crop"].squeeze(0)
pixel_values = processed_outputs.pop("pixel_values").to(
target_dtype)
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [
x.prod().item() + 1 for x in images_spatial_crop
]

# Rename `images` -> `pixel_values` to avoid confusion
processed_outputs["pixel_values"] = list(
pixel_values.split(patches_per_image))
processed_outputs["images_spatial_crop"] = images_spatial_crop
pixel_values = pixel_values.split(patches_per_image)
processed_outputs["pixel_values"] = pixel_values
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,
Expand Down
4 changes: 4 additions & 0 deletions vllm/transformers_utils/processors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)

__all__ = ["DeepseekVLV2Processor"]
Loading

0 comments on commit 5428987

Please sign in to comment.