Skip to content

Commit

Permalink
[Frontend] Multi-Modality Support for Loading Local Image Files (vllm…
Browse files Browse the repository at this point in the history
…-project#9915)

Signed-off-by: chaunceyjiang <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
chaunceyjiang authored and mfournioux committed Nov 20, 2024
1 parent a461a66 commit 9d8dd00
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 14 deletions.
39 changes: 37 additions & 2 deletions tests/multimodal/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import base64
import mimetypes
from tempfile import NamedTemporaryFile
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Tuple

import numpy as np
import pytest
from PIL import Image
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import (async_fetch_image, fetch_image,
Expand Down Expand Up @@ -84,6 +85,40 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
assert _image_equals(data_image_sync, data_image_async)


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
with TemporaryDirectory() as temp_dir:
origin_image = fetch_image(image_url)
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
quality=100,
icc_profile=origin_image.info.get('icc_profile'))

image_async = await async_fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)

image_sync = fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
# Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox()

with pytest.raises(ValueError):
await async_fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
await async_fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")

with pytest.raises(ValueError):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")


@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
def test_repeat_and_pad_placeholder_tokens(model):
config = AutoConfig.from_pretrained(model)
Expand Down
8 changes: 8 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class ModelConfig:
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
Expand Down Expand Up @@ -134,6 +138,7 @@ def __init__(
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
Expand Down Expand Up @@ -164,6 +169,7 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
Expand Down Expand Up @@ -1319,6 +1325,8 @@ def maybe_create_spec_config(
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
allowed_local_media_path=target_model_config.
allowed_local_media_path,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
Expand Down
9 changes: 9 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class EngineArgs:
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: ConfigFormat = ConfigFormat.AUTO
Expand Down Expand Up @@ -269,6 +270,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument(
'--allowed-local-media-path',
type=str,
help="Allowing API requests to read local images or videos"
"from directories specified by the server file system."
"This is a security risk."
"Should only be enabled in trusted environments")
parser.add_argument('--download-dir',
type=nullable_str,
default=EngineArgs.download_dir,
Expand Down Expand Up @@ -920,6 +928,7 @@ def create_model_config(self) -> ModelConfig:
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
Expand Down
9 changes: 7 additions & 2 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url)
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)

placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
Expand All @@ -327,7 +329,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
self._tracker = tracker

def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(image_url)
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)

placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class LLM:
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images
or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
Expand Down Expand Up @@ -139,6 +143,7 @@ def __init__(
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
Expand Down Expand Up @@ -179,6 +184,7 @@ def __init__(
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
Expand Down
75 changes: 65 additions & 10 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import os
from functools import lru_cache
from io import BytesIO
from typing import Any, List, Optional, Tuple, TypeVar, Union
Expand All @@ -18,19 +19,60 @@
cached_get_tokenizer = lru_cache(get_tokenizer)


def _load_image_from_bytes(b: bytes):
def _load_image_from_bytes(b: bytes) -> Image.Image:
image = Image.open(BytesIO(b))
image.load()
return image


def _load_image_from_data_url(image_url: str):
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
# Get the common path
common_path = os.path.commonpath([
os.path.abspath(image_path),
os.path.abspath(allowed_local_media_path)
])
# Check if the common path is the same as allowed_local_media_path
return common_path == os.path.abspath(allowed_local_media_path)


def _load_image_from_file(image_url: str,
allowed_local_media_path: str) -> Image.Image:
if not allowed_local_media_path:
raise ValueError("Invalid 'image_url': Cannot load local files without"
"'--allowed-local-media-path'.")
if allowed_local_media_path:
if not os.path.exists(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': "
f"The path {allowed_local_media_path} does not exist.")
if not os.path.isdir(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': "
f"The path {allowed_local_media_path} must be a directory.")

# Only split once and assume the second part is the image path
_, image_path = image_url.split("file://", 1)
if not _is_subpath(image_path, allowed_local_media_path):
raise ValueError(
f"Invalid 'image_url': The file path {image_path} must"
" be a subpath of '--allowed-local-media-path'"
f" '{allowed_local_media_path}'.")

image = Image.open(image_path)
image.load()
return image


def _load_image_from_data_url(image_url: str) -> Image.Image:
# Only split once and assume the second part is the base64 encoded image
_, image_base64 = image_url.split(",", 1)
return load_image_from_base64(image_base64)


def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
def fetch_image(image_url: str,
*,
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Load a PIL image from a HTTP or base64 data URL.
Expand All @@ -43,16 +85,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)


async def async_fetch_image(image_url: str,
*,
image_mode: str = "RGB") -> Image.Image:
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
Expand All @@ -65,9 +110,11 @@ async def async_fetch_image(image_url: str,

elif image_url.startswith('data:image'):
image = _load_image_from_data_url(image_url)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
"with either 'data:image', 'file://' or 'http'.")

return image.convert(image_mode)

Expand Down Expand Up @@ -126,8 +173,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


def get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = fetch_image(image_url)
def get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = fetch_image(image_url,
allowed_local_media_path=allowed_local_media_path)
return {"image": image}


Expand All @@ -136,8 +187,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
return {"audio": (audio, sr)}


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await async_fetch_image(image_url)
async def async_get_and_parse_image(
image_url: str,
*,
allowed_local_media_path: str = "") -> MultiModalDataDict:
image = await async_fetch_image(
image_url, allowed_local_media_path=allowed_local_media_path)
return {"image": image}


Expand Down

0 comments on commit 9d8dd00

Please sign in to comment.