diff --git a/vllm/config.py b/vllm/config.py
index 17e9b1c100498..9661df368f1d8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -55,6 +55,10 @@ class ModelConfig:
             "mistral" will always use the tokenizer from `mistral_common`.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
+        allowed_local_media_path: Allows API requests to read local images or
+            videos from directories specified by the server file system.
+            This is a security risk. Should only be enabled in trusted
+            environments.
         dtype: Data type for model weights and activations. The "auto" option
             will use FP16 precision for FP32 and FP16 models, and BF16
             precision for BF16 models.
@@ -134,6 +138,7 @@ def __init__(
         trust_remote_code: bool,
         dtype: Union[str, torch.dtype],
         seed: int,
+        allowed_local_media_path: str = "",
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
         rope_scaling: Optional[dict] = None,
@@ -164,6 +169,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
         self.trust_remote_code = trust_remote_code
+        self.allowed_local_media_path = allowed_local_media_path
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
@@ -1319,6 +1325,8 @@ def maybe_create_spec_config(
             tokenizer=target_model_config.tokenizer,
             tokenizer_mode=target_model_config.tokenizer_mode,
             trust_remote_code=target_model_config.trust_remote_code,
+            allowed_local_media_path=target_model_config.
+            allowed_local_media_path,
             dtype=target_model_config.dtype,
             seed=target_model_config.seed,
             revision=draft_revision,
@@ -1386,10 +1394,10 @@
             ngram_prompt_lookup_max,
             ngram_prompt_lookup_min,
             draft_token_acceptance_method=draft_token_acceptance_method,
-            typical_acceptance_sampler_posterior_threshold=\
-                typical_acceptance_sampler_posterior_threshold,
-            typical_acceptance_sampler_posterior_alpha=\
-                typical_acceptance_sampler_posterior_alpha,
+            typical_acceptance_sampler_posterior_threshold=
+            typical_acceptance_sampler_posterior_threshold,
+            typical_acceptance_sampler_posterior_alpha=
+            typical_acceptance_sampler_posterior_alpha,
             disable_logprobs=disable_logprobs,
             disable_log_stats=disable_log_stats,
         )
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index da06ab186821e..b63645ad545bd 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -92,6 +92,7 @@ class EngineArgs:
     tokenizer_mode: str = 'auto'
     chat_template_text_format: str = 'string'
     trust_remote_code: bool = False
+    allowed_local_media_path: str = ""
     download_dir: Optional[str] = None
     load_format: str = 'auto'
     config_format: ConfigFormat = ConfigFormat.AUTO
@@ -269,6 +270,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
+        parser.add_argument(
+            '--allowed-local-media-path',
+            type=str,
+            default=EngineArgs.allowed_local_media_path,
+            help="Allow API requests to read local images or videos from "
+            "directories specified by the server file system. This is a "
+            "security risk. Should only be enabled in trusted environments.")
         parser.add_argument('--download-dir',
                             type=nullable_str,
                             default=EngineArgs.download_dir,
@@ -920,6 +928,7 @@ def create_model_config(self) -> ModelConfig:
             tokenizer_mode=self.tokenizer_mode,
             chat_template_text_format=self.chat_template_text_format,
             trust_remote_code=self.trust_remote_code,
+            allowed_local_media_path=self.allowed_local_media_path,
             dtype=self.dtype,
             seed=self.seed,
             revision=self.revision,
@@ -971,8 +980,8 @@ def create_engine_config(self) -> VllmConfig:
                 f"'bitsandbytes' load format, but got {self.load_format}")
 
         if (self.load_format == "bitsandbytes" or
-            self.qlora_adapter_name_or_path is not None) and \
-            self.quantization != "bitsandbytes":
+                self.qlora_adapter_name_or_path is not None) and \
+                self.quantization != "bitsandbytes":
             raise ValueError(
                 "BitsAndBytes load format and QLoRA adapter only support "
                 f"'bitsandbytes' quantization, but got {self.quantization}")
@@ -1059,10 +1068,9 @@ def create_engine_config(self) -> VllmConfig:
             target_parallel_config=parallel_config,
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
-            speculative_model_quantization = \
-                self.speculative_model_quantization,
-            speculative_draft_tensor_parallel_size = \
-                self.speculative_draft_tensor_parallel_size,
+            speculative_model_quantization=self.speculative_model_quantization,
+            speculative_draft_tensor_parallel_size=self.
+            speculative_draft_tensor_parallel_size,
             num_speculative_tokens=self.num_speculative_tokens,
             speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
             speculative_disable_by_batch_size=self.
@@ -1072,8 +1080,7 @@ def create_engine_config(self) -> VllmConfig:
             disable_log_stats=self.disable_log_stats,
             ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
             ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
-            draft_token_acceptance_method=\
-                self.spec_decoding_acceptance_method,
+            draft_token_acceptance_method=self.spec_decoding_acceptance_method,
             typical_acceptance_sampler_posterior_threshold=self.
             typical_acceptance_sampler_posterior_threshold,
             typical_acceptance_sampler_posterior_alpha=self.
@@ -1136,7 +1143,7 @@ def create_engine_config(self) -> VllmConfig:
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
 
         if self.qlora_adapter_name_or_path is not None and \
-            self.qlora_adapter_name_or_path != "":
+                self.qlora_adapter_name_or_path != "":
             if self.model_loader_extra_config is None:
                 self.model_loader_extra_config = {}
             self.model_loader_extra_config[
@@ -1147,7 +1154,7 @@ def create_engine_config(self) -> VllmConfig:
         prompt_adapter_config = PromptAdapterConfig(
             max_prompt_adapters=self.max_prompt_adapters,
             max_prompt_adapter_token=self.max_prompt_adapter_token) \
-            if self.enable_prompt_adapter else None
+                if self.enable_prompt_adapter else None
 
         decoding_config = DecodingConfig(
             guided_decoding_backend=self.guided_decoding_backend)
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index c9552977710d1..356d093a200f5 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -62,7 +62,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
 class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
     """A simpler version of the param that only accepts a plain image_url.
     This is supported by OpenAI API, although it is not documented.
-    
+
     Example:
     {
         "image_url": "https://example.com/image.jpg"
     }
@@ -73,7 +73,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
 class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
     """A simpler version of the param that only accepts a plain audio_url.
-    
+
     Example:
     {
         "audio_url": "https://example.com/audio.mp3"
     }
@@ -307,7 +309,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
         self._tracker = tracker
 
     def parse_image(self, image_url: str) -> None:
-        image = get_and_parse_image(image_url)
+        image = get_and_parse_image(image_url,
+                                    allowed_local_media_path=self._tracker.
+                                    _model_config.allowed_local_media_path)
         placeholder = self._tracker.add("image", image)
 
         self._add_placeholder(placeholder)
@@ -327,7 +329,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
         self._tracker = tracker
 
     def parse_image(self, image_url: str) -> None:
-        image_coro = async_get_and_parse_image(image_url)
+        image_coro = async_get_and_parse_image(
+            image_url,
+            allowed_local_media_path=self._tracker._model_config.
+            allowed_local_media_path)
         placeholder = self._tracker.add("image", image_coro)
 
         self._add_placeholder(placeholder)
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 3d62cb3598477..78a57041b4089 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -58,6 +58,10 @@ class LLM:
             from the input.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
+        allowed_local_media_path: Allows API requests to read local images
+            or videos from directories specified by the server file system.
+            This is a security risk. Should only be enabled in trusted
+            environments.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -139,6 +143,7 @@ def __init__(
         tokenizer_mode: str = "auto",
         skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
+        allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         quantization: Optional[str] = None,
@@ -179,6 +184,7 @@ def __init__(
             tokenizer_mode=tokenizer_mode,
             skip_tokenizer_init=skip_tokenizer_init,
             trust_remote_code=trust_remote_code,
+            allowed_local_media_path=allowed_local_media_path,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             quantization=quantization,
@@ -475,7 +481,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                                 logprob_obj.logprob)
 
                     if token_id == tokenizer.eos_token_id and \
-                        not ignore_eos:
+                            not ignore_eos:
                         instance.completed.append(new_beam)
                     else:
                         instance_new_beams.append(new_beam)
@@ -931,7 +937,8 @@ def _run_engine(
                     # Calculate tokens only for RequestOutput
                     assert output.prompt_token_ids is not None
                     total_in_toks += len(output.prompt_token_ids)
-                    in_spd = total_in_toks / pbar.format_dict["elapsed"]
+                    in_spd = total_in_toks / \
+                        pbar.format_dict["elapsed"]
                     total_out_toks += sum(
                         len(stp.token_ids) for stp in output.outputs)
                     out_spd = (total_out_toks /
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index c5ff552e06099..df6ae5ce19160 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -1,6 +1,7 @@
 import base64
 from functools import lru_cache
 from io import BytesIO
+from os import path
 from typing import Any, List, Optional, Tuple, TypeVar, Union
 
 import numpy as np
@@ -18,19 +19,54 @@
 cached_get_tokenizer = lru_cache(get_tokenizer)
 
 
-def _load_image_from_bytes(b: bytes):
+def _load_image_from_bytes(b: bytes) -> Image.Image:
     image = Image.open(BytesIO(b))
     image.load()
     return image
 
 
-def _load_image_from_data_url(image_url: str):
+def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
+    # Normalize both paths so relative segments cannot escape the allowed dir
+    common_path = path.commonpath(
+        [path.abspath(image_path), path.abspath(allowed_local_media_path)])
+    return common_path == path.abspath(allowed_local_media_path)
+
+
+def _load_image_from_file(image_url: str,
+                          allowed_local_media_path: str) -> Image.Image:
+    if not allowed_local_media_path:
+        raise ValueError("Invalid 'image_url': Cannot load local files without"
+                         " '--allowed-local-media-path'.")
+    if allowed_local_media_path:
+        if not path.exists(allowed_local_media_path):
+            raise ValueError(
+                "Invalid '--allowed-local-media-path': The path does not exist."
+            )
+        if not path.isdir(allowed_local_media_path):
+            raise ValueError("Invalid '--allowed-local-media-path': "
+                             "The path must be a directory.")
+
+    # Only split once and assume the second part is the image path
+    _, image_path = image_url.split("file://", 1)
+    if not _is_subpath(image_path, allowed_local_media_path):
+        raise ValueError("Invalid 'image_url': The file path must be a"
+                         " subpath of '--allowed-local-media-path'.")
+
+    image = Image.open(image_path)
+    image.load()
+    return image
+
+
+def _load_image_from_data_url(image_url: str) -> Image.Image:
     # Only split once and assume the second part is the base64 encoded image
     _, image_base64 = image_url.split(",", 1)
     return load_image_from_base64(image_base64)
 
 
-def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
+def fetch_image(image_url: str,
+                *,
+                image_mode: str = "RGB",
+                allowed_local_media_path: str = "") -> Image.Image:
     """
     Load a PIL image from a HTTP or base64 data URL.
@@ -43,16 +79,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
 
     elif image_url.startswith('data:image'):
         image = _load_image_from_data_url(image_url)
+    elif image_url.startswith('file://'):
+        image = _load_image_from_file(image_url, allowed_local_media_path)
     else:
         raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
-                         "with either 'data:image' or 'http'.")
+                         "with either 'data:image', 'file://' or 'http'.")
 
     return image.convert(image_mode)
 
 
 async def async_fetch_image(image_url: str,
                             *,
-                            image_mode: str = "RGB") -> Image.Image:
+                            image_mode: str = "RGB",
+                            allowed_local_media_path: str = "") -> Image.Image:
     """
     Asynchronously load a PIL image from a HTTP or base64 data URL.
@@ -65,9 +104,11 @@ async def async_fetch_image(image_url: str,
 
     elif image_url.startswith('data:image'):
         image = _load_image_from_data_url(image_url)
+    elif image_url.startswith('file://'):
+        image = _load_image_from_file(image_url, allowed_local_media_path)
     else:
         raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
-                         "with either 'data:image' or 'http'.")
+                         "with either 'data:image', 'file://' or 'http'.")
 
     return image.convert(image_mode)
@@ -126,8 +167,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
     return {"audio": (audio, sr)}
 
 
-def get_and_parse_image(image_url: str) -> MultiModalDataDict:
-    image = fetch_image(image_url)
+def get_and_parse_image(
+        image_url: str,
+        *,
+        allowed_local_media_path: str = "") -> MultiModalDataDict:
+    image = fetch_image(image_url,
+                        allowed_local_media_path=allowed_local_media_path)
     return {"image": image}
@@ -136,8 +181,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
     return {"audio": (audio, sr)}
 
 
-async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
-    image = await async_fetch_image(image_url)
+async def async_get_and_parse_image(
+        image_url: str,
+        *,
+        allowed_local_media_path: str = "") -> MultiModalDataDict:
+    image = await async_fetch_image(
+        image_url, allowed_local_media_path=allowed_local_media_path)
     return {"image": image}
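
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new allowed_local_media_path plumbing is exercised end to end through the offline LLM.chat entrypoint. The model name, directory, and file name are assumptions chosen for the example; any multimodal chat model and any directory on the serving host would work the same way. The equivalent server-side setup is starting the OpenAI-compatible server with --allowed-local-media-path and sending a chat request whose image_url uses the file:// scheme.

# Illustrative example only: the model name and paths are assumed
# placeholders, not values taken from this diff.
from vllm import LLM

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",  # any multimodal chat model
    allowed_local_media_path="/var/lib/vllm/images",  # trusted dir only
)

# The file:// URL is resolved by _load_image_from_file; _is_subpath rejects
# any path that does not land inside allowed_local_media_path.
outputs = llm.chat([{
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this image?"},
        {
            "type": "image_url",
            "image_url": {"url": "file:///var/lib/vllm/images/example.jpg"},
        },
    ],
}])
print(outputs[0].outputs[0].text)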