Skip to content

Commit

Permalink
[Misc] Abstract the logic for reading and writing media content (vllm…
Browse files Browse the repository at this point in the history
…-project#11527)

Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Dec 27, 2024
1 parent 2c9b8ea commit 7af553e
Show file tree
Hide file tree
Showing 10 changed files with 493 additions and 387 deletions.
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class MockModelConfig:
hf_config = MockHFConfig()
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
Expand Down
6 changes: 1 addition & 5 deletions tests/entrypoints/test_chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Optional

import pytest
from PIL import Image

from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
Expand Down Expand Up @@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input(
image_data = mm_data.get("image")
assert image_data is not None

if image_count == 1:
assert isinstance(image_data, Image.Image)
else:
assert isinstance(image_data, list) and len(image_data) == image_count
assert isinstance(image_data, list) and len(image_data) == image_count


def test_parse_chat_messages_single_image(
Expand Down
59 changes: 34 additions & 25 deletions tests/multimodal/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.utils import (async_fetch_image, fetch_image,
from vllm.multimodal.utils import (MediaConnector,
repeat_and_pad_placeholder_tokens)

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
Expand All @@ -23,7 +23,12 @@

@pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]:
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
connector = MediaConnector()

return {
image_url: connector.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}


def get_supported_suffixes() -> Tuple[str, ...]:
Expand All @@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url)
image_async = await async_fetch_image(image_url)
connector = MediaConnector()

image_sync = connector.fetch_image(image_url)
image_async = await connector.fetch_image_async(image_url)
assert _image_equals(image_sync, image_async)


Expand All @@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
image_url: str, suffix: str):
connector = MediaConnector()
url_image = url_images[image_url]

try:
Expand All @@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"

data_image_sync = fetch_image(data_url)
data_image_sync = connector.fetch_image(data_url)
if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image_sync)
else:
pass # Lossy format; only check that image can be opened

data_image_async = await async_fetch_image(data_url)
data_image_async = await connector.fetch_image_async(data_url)
assert _image_equals(data_image_sync, data_image_async)


@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
connector = MediaConnector()

with TemporaryDirectory() as temp_dir:
origin_image = fetch_image(image_url)
local_connector = MediaConnector(allowed_local_media_path=temp_dir)

origin_image = connector.fetch_image(image_url)
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
quality=100,
icc_profile=origin_image.info.get('icc_profile'))

image_async = await async_fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)

image_sync = fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{os.path.basename(image_url)}")
image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}")
# Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox()

with pytest.raises(ValueError):
await async_fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
await async_fetch_image(
with pytest.raises(ValueError, match="must be a subpath"):
await local_connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(RuntimeError, match="Cannot load local files"):
await connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}")

with pytest.raises(ValueError):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError, match="must be a subpath"):
local_connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(RuntimeError, match="Cannot load local files"):
connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")


@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
Expand Down
6 changes: 2 additions & 4 deletions vllm/assets/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@ class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]

@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr
return librosa.load(audio_path, sr=None)

@property
def url(self) -> str:
Expand Down
Loading

0 comments on commit 7af553e

Please sign in to comment.