Skip to content

Commit

Permalink
[Misc] Multi-Modality Support for Loading Local image Files
Browse files Browse the repository at this point in the history
FIX #8730

Signed-off-by: chaunceyjiang <[email protected]>
  • Loading branch information
chaunceyjiang committed Nov 4, 2024
1 parent b67feb1 commit 8c3572b
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 34 deletions.
12 changes: 8 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class ModelConfig:
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allow API requests to read local images or
videos from this directory on the server file system. Leave empty
to disable local file access. This is a security risk and should
only be enabled in trusted environments.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
Expand Down Expand Up @@ -132,6 +136,7 @@ def __init__(
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
allowed_local_media_path: str,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
Expand Down Expand Up @@ -164,6 +169,7 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
Expand Down Expand Up @@ -1386,10 +1392,8 @@ def maybe_create_spec_config(
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
typical_acceptance_sampler_posterior_threshold=typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
)
Expand Down
32 changes: 18 additions & 14 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: ConfigFormat = ConfigFormat.AUTO
Expand Down Expand Up @@ -269,6 +270,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument(
'--allowed-local-media-path',
type=str,
help="Allowing API requests to read local images or videos"
"from directories specified by the server file system."
"This is a security risk."
"Should only be enabled in trusted environments")
parser.add_argument('--download-dir',
type=nullable_str,
default=EngineArgs.download_dir,
Expand Down Expand Up @@ -658,8 +666,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'--speculative-model',
type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
help='The name of the draft model to be used in speculative decoding.')

Check failure on line 669 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.10)

Ruff (E501)

vllm/engine/arg_utils.py:669:81: E501 Line too long (83 > 80)

Check failure on line 669 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.11)

Ruff (E501)

vllm/engine/arg_utils.py:669:81: E501 Line too long (83 > 80)
# Quantization settings for speculative model.
parser.add_argument(
'--speculative-model-quantization',
Expand All @@ -681,8 +688,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--speculative-disable-mqa-scorer',
action='store_true',
help=
'If set to True, the MQA scorer will be disabled in speculative '
help='If set to True, the MQA scorer will be disabled in speculative '

Check failure on line 691 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.10)

Ruff (E501)

vllm/engine/arg_utils.py:691:81: E501 Line too long (82 > 80)

Check failure on line 691 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.11)

Ruff (E501)

vllm/engine/arg_utils.py:691:81: E501 Line too long (82 > 80)
' and fall back to batch expansion')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
Expand Down Expand Up @@ -920,6 +926,7 @@ def create_model_config(self) -> ModelConfig:
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
Expand Down Expand Up @@ -971,8 +978,8 @@ def create_engine_config(self) -> VllmConfig:
f"'bitsandbytes' load format, but got {self.load_format}")

if (self.load_format == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
Expand Down Expand Up @@ -1059,10 +1066,8 @@ def create_engine_config(self) -> VllmConfig:
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_model_quantization = \
self.speculative_model_quantization,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
speculative_model_quantization=self.speculative_model_quantization,
speculative_draft_tensor_parallel_size=self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
Expand All @@ -1072,8 +1077,7 @@ def create_engine_config(self) -> VllmConfig:
disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
draft_token_acceptance_method=self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
Expand Down Expand Up @@ -1136,7 +1140,7 @@ def create_engine_config(self) -> VllmConfig:
and self.max_cpu_loras > 0 else None) if self.enable_lora else None

if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[
Expand All @@ -1147,7 +1151,7 @@ def create_engine_config(self) -> VllmConfig:
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
if self.enable_prompt_adapter else None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
Expand Down
13 changes: 9 additions & 4 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
Expand All @@ -73,7 +73,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):

class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
Expand Down Expand Up @@ -307,7 +307,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url)
image = get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.allowed_local_media_path)

placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
Expand All @@ -327,7 +329,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(image_url)
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.allowed_local_media_path
)

placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
Expand Down
11 changes: 9 additions & 2 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class LLM:
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allow API requests to read local images
or videos from this directory on the server file system. Leave
empty to disable local file access. This is a security risk and
should only be enabled in trusted environments.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
Expand Down Expand Up @@ -139,6 +143,7 @@ def __init__(
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
Expand Down Expand Up @@ -179,6 +184,7 @@ def __init__(
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
Expand Down Expand Up @@ -475,7 +481,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
not ignore_eos:
not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
Expand Down Expand Up @@ -931,7 +937,8 @@ def _run_engine(
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
in_spd = total_in_toks / \
pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = (total_out_toks /
Expand Down
57 changes: 47 additions & 10 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import lru_cache
from io import BytesIO
from typing import Any, List, Optional, Tuple, TypeVar, Union
from os import path

import numpy as np
import numpy.typing as npt
Expand All @@ -18,19 +19,49 @@
cached_get_tokenizer = lru_cache(get_tokenizer)


def _load_image_from_bytes(b: bytes):
def _load_image_from_bytes(b: bytes) -> Image.Image:
    """Decode encoded image bytes into a fully loaded PIL image.

    Args:
        b: The raw, still-encoded image file contents.

    Returns:
        The decoded image with its pixel data already read.
    """
    buffer = BytesIO(b)
    decoded = Image.open(buffer)
    # Force the pixel data to be read now rather than on first access.
    decoded.load()
    return decoded


def _load_image_from_data_url(image_url: str):
def _is_subpath(image_path, allowed_local_media_path):
# Get the common path
common_path = path.commonpath([image_path, allowed_local_media_path])
# Check if the common path is the same as allowed_local_media_path
return common_path == path.abspath(allowed_local_media_path)


def _load_image_from_file(image_url: str,
                          allowed_local_media_path: str) -> Image.Image:
    """Load a PIL image referenced by a ``file://`` URL.

    Local reads are only permitted for files located inside the directory
    configured through ``--allowed-local-media-path``.

    Args:
        image_url: URL of the form ``file://<path>``.
        allowed_local_media_path: Directory that local files may be read
            from. An empty value disables local file loading entirely.

    Returns:
        The image, with its pixel data already loaded.

    Raises:
        ValueError: If local loading is disabled, if the allowed path does
            not exist or is not a directory, or if the requested file is
            not located inside the allowed path.
    """
    if not allowed_local_media_path:
        raise ValueError(
            "Invalid 'image_url': Cannot load local files without "
            "'--allowed-local-media-path'.")
    if not path.exists(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            "The path does not exist.")
    if not path.isdir(allowed_local_media_path):
        raise ValueError(
            "Invalid '--allowed-local-media-path': "
            "The path must be a directory.")

    # Only split once and assume the second part is the image path
    _, image_path = image_url.split("file://", 1)

    # Canonicalize both sides before comparing them: the path comes from an
    # API request, so ".." segments and symlinks must not be able to point
    # outside the allowed directory. The same resolved path is then opened,
    # so the file that was checked is the file that is read.
    allowed_root = path.realpath(allowed_local_media_path)
    image_path = path.realpath(image_path)
    if not _is_subpath(image_path, allowed_root):
        raise ValueError(
            "Invalid 'image_url': The file path must be a subpath of "
            "'--allowed-local-media-path'.")

    image = Image.open(image_path)
    image.load()
    return image


def _load_image_from_data_url(image_url: str) -> Image.Image:
    """Decode the image carried inline by a ``data:`` URL.

    Args:
        image_url: A data URL whose payload follows the first comma.

    Returns:
        The image produced by ``load_image_from_base64`` for that payload.
    """
    # Split at the first comma only; everything after it is treated as the
    # base64-encoded image.
    _header, payload = image_url.split(",", 1)
    return load_image_from_base64(payload)


def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
def fetch_image(image_url: str, *, image_mode: str = "RGB", allowed_local_media_path: str = "") -> Image.Image:

Check failure on line 64 in vllm/multimodal/utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.10)

Ruff (E501)

vllm/multimodal/utils.py:64:81: E501 Line too long (111 > 80)

Check failure on line 64 in vllm/multimodal/utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.11)

Ruff (E501)

vllm/multimodal/utils.py:64:81: E501 Line too long (111 > 80)
"""
Load a PIL image from a HTTP or base64 data URL.
Expand All @@ -43,16 +74,18 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)


async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
image_mode: str = "RGB", allowed_local_media_path: str = "") -> Image.Image:

Check failure on line 88 in vllm/multimodal/utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.10)

Ruff (E501)

vllm/multimodal/utils.py:88:81: E501 Line too long (104 > 80)

Check failure on line 88 in vllm/multimodal/utils.py

View workflow job for this annotation

GitHub Actions / ruff (3.11)

Ruff (E501)

vllm/multimodal/utils.py:88:81: E501 Line too long (104 > 80)
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
Expand All @@ -65,9 +98,11 @@ async def async_fetch_image(image_url: str,

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)

Expand Down Expand Up @@ -126,8 +161,9 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


def get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = fetch_image(image_url)
def get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Fetch the image at ``image_url`` and wrap it as multi-modal data.

    Args:
        image_url: Location of the image, passed through to ``fetch_image``.
        allowed_local_media_path: Forwarded to ``fetch_image`` unchanged.

    Returns:
        A dict holding the fetched image under the ``"image"`` key.
    """
    fetched = fetch_image(
        image_url,
        allowed_local_media_path=allowed_local_media_path,
    )
    return {"image": fetched}


Expand All @@ -136,8 +172,9 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await async_fetch_image(image_url)
async def async_get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    """Asynchronously fetch an image and wrap it as multi-modal data.

    Args:
        image_url: Location of the image, passed through to
            ``async_fetch_image``.
        allowed_local_media_path: Forwarded to ``async_fetch_image``
            unchanged.

    Returns:
        A dict holding the fetched image under the ``"image"`` key.
    """
    fetched = await async_fetch_image(
        image_url,
        allowed_local_media_path=allowed_local_media_path,
    )
    return {"image": fetched}


Expand Down

0 comments on commit 8c3572b

Please sign in to comment.