From 242ea246c015cd268310cc75bccf53b78ead4fff Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Fri, 13 Dec 2024 16:20:29 +0800 Subject: [PATCH 01/16] [Feature] Add load generation config from model Signed-off-by: liuyanyi --- ...nference_with_default_generation_config.py | 30 +++++++ vllm/config.py | 90 ++++++++++++------- vllm/engine/arg_utils.py | 11 +++ vllm/engine/llm_engine.py | 23 +---- vllm/entrypoints/llm.py | 28 +++++- vllm/entrypoints/openai/api_server.py | 44 ++++++++- vllm/v1/engine/processor.py | 20 +---- 7 files changed, 173 insertions(+), 73 deletions(-) create mode 100644 examples/offline_inference_with_default_generation_config.py diff --git a/examples/offline_inference_with_default_generation_config.py b/examples/offline_inference_with_default_generation_config.py new file mode 100644 index 0000000000000..346bb80b1e23f --- /dev/null +++ b/examples/offline_inference_with_default_generation_config.py @@ -0,0 +1,30 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM with built-in default generation config. +# The generation config is set to None by default to keep +# the behavior consistent with the previous version. +# If you want to use the default generation config from the model, +# you should set the generation_config to "auto". +llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto") + +# Load the default sampling parameters from the model. +sampling_params = llm.get_default_sampling_params() +# Modify the sampling parameters if needed. +sampling_params.temperature = 0.5 + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/config.py b/vllm/config.py index 08a7b607630af..23d0a47cf204f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -26,7 +26,8 @@ from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, - get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) + get_sentence_transformer_tokenizer_config, is_encoder_decoder, + try_get_generation_config, uses_mrope) from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, get_cpu_memory, print_warning_once, random_uuid, resolve_obj_by_qualname) @@ -156,41 +157,42 @@ class ModelConfig: can not be gathered from the vllm arguments. override_pooler_config: Initialize non default pooling config or override default pooling config for the pooling model. + generation_config: Configuration parameter file for generation. """ - def __init__( - self, - model: str, - task: Union[TaskOption, Literal["draft"]], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - mm_cache_preprocessor: bool = False, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None) -> None: + def __init__(self, + model: str, + task: Union[TaskOption, Literal["draft"]], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + mm_cache_preprocessor: bool = False, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + generation_config: Optional[str] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -317,6 +319,8 @@ def __init__( self.pooler_config = self._init_pooler_config(override_pooler_config) + self.generation_config = generation_config + self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -782,6 +786,24 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config + def try_get_generation_config(self) -> Dict[str, Any]: + if self.generation_config is None or self.generation_config == "auto": + config = try_get_generation_config( + self.model, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + ) + else: + config = try_get_generation_config( + self.generation_config, + trust_remote_code=self.trust_remote_code, + ) + + if config is None: + return {} + + return config.to_diff_dict() + @property def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0c28fe7032728..f6d7503fdafb9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -198,6 +198,8 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None + generation_config: Optional[str] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -932,6 +934,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="auto", help='The worker class to use for distributed execution.') + parser.add_argument( + "--generation-config", + type=nullable_str, + default=None, + help="The folder path to the generation config. If set to " + "'auto', the generation config will be automatically loaded " + "from the model's folder.", + ) + return parser @classmethod diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9be30c635cb2c..97bdfdf3a947a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,8 +5,8 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, - Iterable, List, Mapping, NamedTuple, Optional) +from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, + List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence from typing import Set, Type, Union, cast, overload @@ -53,7 +53,6 @@ SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) -from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( @@ -66,20 +65,6 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} - - return config.to_diff_dict() - - _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput) @@ -275,8 +260,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: return tokenizer_group.get_lora_tokenizer(sequence.lora_request) self.seq_counter = Counter() - self.generation_config_fields = _load_generation_config_dict( - self.model_config) + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0bec978c4869c..0b10ec6734848 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -230,6 +230,11 @@ def __init__( self.request_counter = Counter() + # Generation config fields + self.overwrite_default_sampling_params = (engine_args.generation_config + is not None) + self.generation_config_fields = self.llm_engine.generation_config_fields + @staticmethod def get_engine_class() -> Type[LLMEngine]: if envs.VLLM_USE_V1: @@ -252,6 +257,27 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: else: tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + def get_default_sampling_params(self) -> SamplingParams: + if self.overwrite_default_sampling_params: + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + ] + default_param_dict = { + param: self.generation_config_fields.get(param, None) + for param in available_params + } + # Filter the None values + default_param_dict = { + k: v + for k, v in default_param_dict.items() if v is not None + } + return SamplingParams.from_optional(**default_param_dict) + return SamplingParams() + @overload def generate( self, @@ -417,7 +443,7 @@ def generate( if sampling_params is None: # Use default sampling params. - sampling_params = SamplingParams() + sampling_params = self.get_default_sampling_params() self._validate_and_add_requests( prompts=parsed_prompts, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2e27224b41864..a65c457f4ecb3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -352,7 +352,8 @@ async def show_version(): return JSONResponse(content=ver) -@router.post("/v1/chat/completions") +# Lazy include chat completion routes to make sure new sampling +# parameters are included async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): handler = chat(raw_request) @@ -372,7 +373,8 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions") +# Lazy include chat completion routes to make sure new sampling +# parameters are included async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) if handler is None: @@ -568,6 +570,35 @@ def init_app_state( resolved_chat_template = load_chat_template(args.chat_template) logger.info("Using supplied chat template:\n%s", resolved_chat_template) + if args.generation_config: + generation_config_fields = model_config.try_get_generation_config() + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + ] + + if any(p in generation_config_fields for p in available_params): + overwrite_config = { + p: generation_config_fields.get(p, None) + for p in available_params + } + logger.info("Overwriting generation config with: %s", + overwrite_config) + # Modify the ChatCompletionRequest to include the generation config + for k, v in overwrite_config.items(): + if v is not None: + ChatCompletionRequest.model_fields[k].default = v + CompletionRequest.model_fields[k].default = v + + # Rebuild the models to include the new fields + ChatCompletionRequest.model_rebuild(force=True) + CompletionRequest.model_rebuild(force=True) + else: + logger.warning("No generation config found.") + state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, @@ -661,6 +692,15 @@ def signal_handler(*_) -> None: model_config = await engine_client.get_model_config() init_app_state(engine_client, model_config, app.state, args) + # Lazy include chat completion routes to make sure new sampling + # parameters are included + app.add_api_route("/v1/chat/completions", + create_chat_completion, + methods=["POST"]) + app.add_api_route("/v1/completions", + create_completion, + methods=["POST"]) + shutdown_task = await serve_http( app, host=args.host, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 903996bad3726..5ce9952243d0c 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, Mapping, Optional, Tuple, Union +from typing import Mapping, Optional, Tuple, Union from vllm.config import LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, @@ -12,7 +12,6 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient @@ -33,8 +32,8 @@ def __init__( self.lora_config = lora_config self.tokenizer = tokenizer - self.generation_config_fields = _load_generation_config_dict( - model_config) + self.generation_config_fields = model_config.try_get_generation_config( + ) self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer, mm_registry) @@ -179,16 +178,3 @@ def _validate_model_inputs(self, inputs: ProcessorInputs): # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens - - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} - - return config.to_diff_dict() From bdedf17dbb324cee34024c5ff3914f8f325ed887 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Fri, 13 Dec 2024 17:19:37 +0800 Subject: [PATCH 02/16] Fix v1 engine Signed-off-by: liuyanyi --- vllm/v1/engine/llm_engine.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 1b3a9f12d009e..fc133567fe4be 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -76,6 +76,10 @@ def __init__( asyncio_mode=False, ) + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) + + @classmethod def from_engine_args( cls, From c24141786d5346d822cd58a6d2aa5595224ca8e0 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Fri, 13 Dec 2024 17:24:33 +0800 Subject: [PATCH 03/16] Fix lint Signed-off-by: liuyanyi --- vllm/v1/engine/llm_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index fc133567fe4be..90041c576f216 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -79,7 +79,6 @@ def __init__( self.generation_config_fields = ( self.model_config.try_get_generation_config()) - @classmethod def from_engine_args( cls, From 4ee9aacf20b4f93e46cf15b420bf64585fdeb905 Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Fri, 13 Dec 2024 15:41:19 +0000 Subject: [PATCH 04/16] fix Signed-off-by: Yanyi Liu --- vllm/engine/arg_utils.py | 9 +++++---- vllm/entrypoints/llm.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f6d7503fdafb9..310b826aee7f1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -938,10 +938,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--generation-config", type=nullable_str, default=None, - help="The folder path to the generation config. If set to " - "'auto', the generation config will be automatically loaded " - "from the model's folder.", - ) + help="The folder path to the generation config. " + "Defaults to None, will use the default generation config in vLLM. " + "If set to 'auto', the generation config will be automatically loaded " + "from model. If set to a folder path, the generation config " + "will be loaded from the specified folder path.") return parser diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0b10ec6734848..427e7982056fe 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -230,10 +230,9 @@ def __init__( self.request_counter = Counter() - # Generation config fields + # Overwrite default sampling when generation config is set self.overwrite_default_sampling_params = (engine_args.generation_config is not None) - self.generation_config_fields = self.llm_engine.generation_config_fields @staticmethod def get_engine_class() -> Type[LLMEngine]: @@ -259,6 +258,7 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: def get_default_sampling_params(self) -> SamplingParams: if self.overwrite_default_sampling_params: + generation_config_fields = self.llm_engine.generation_config_fields available_params = [ "repetition_penalty", "temperature", @@ -267,7 +267,7 @@ def get_default_sampling_params(self) -> SamplingParams: "min_p", ] default_param_dict = { - param: self.generation_config_fields.get(param, None) + param: generation_config_fields.get(param, None) for param in available_params } # Filter the None values From 7f2d07e6f50d9db52240a35ff95792369e3b16a0 Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Fri, 13 Dec 2024 16:05:15 +0000 Subject: [PATCH 05/16] move Signed-off-by: Yanyi Liu --- vllm/entrypoints/llm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 427e7982056fe..8bdbc92431d42 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -230,10 +230,6 @@ def __init__( self.request_counter = Counter() - # Overwrite default sampling when generation config is set - self.overwrite_default_sampling_params = (engine_args.generation_config - is not None) - @staticmethod def get_engine_class() -> Type[LLMEngine]: if envs.VLLM_USE_V1: @@ -257,7 +253,11 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) def get_default_sampling_params(self) -> SamplingParams: - if self.overwrite_default_sampling_params: + # Overwrite default sampling when generation config is set + overwrite_default_sampling_params = ( + self.llm_engine.model_config.generation_config is not None + ) + if overwrite_default_sampling_params: generation_config_fields = self.llm_engine.generation_config_fields available_params = [ "repetition_penalty", From ed0b9b046af6f231596a0fe535cac63d274b06cb Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Fri, 13 Dec 2024 17:03:36 +0000 Subject: [PATCH 06/16] Unify diff sample param into model config Signed-off-by: Yanyi Liu --- vllm/config.py | 32 +++++++++++++++++++ vllm/engine/arg_utils.py | 4 +-- vllm/entrypoints/llm.py | 27 +++------------- vllm/entrypoints/openai/api_server.py | 29 ----------------- vllm/entrypoints/openai/serving_chat.py | 10 ++++++ vllm/entrypoints/openai/serving_completion.py | 11 +++++++ 6 files changed, 59 insertions(+), 54 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 23d0a47cf204f..51941611af7e3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -804,6 +804,38 @@ def try_get_generation_config(self) -> Dict[str, Any]: return config.to_diff_dict() + def get_diff_sampling_param(self) -> Dict[str, Any]: + """ + This method returns a dictionary containing the parameters + that differ from the default sampling parameters, but only + if `generation_config` is set. If `generation_config` is not + set, an empty dictionary is returned. + + Returns: + Dict[str, Any]: A dictionary with the differing sampling + parameters if `generation_config` is set, otherwise an + empty dictionary. + """ + if self.generation_config is None: + # When generation_config is not set + return {} + config = self.try_get_generation_config() + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + ] + if any(p in config for p in available_params): + diff_sampling_param = { + p: config.get(p, None) + for p in available_params + } + else: + diff_sampling_param = {} + return diff_sampling_param + @property def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 310b826aee7f1..885ca068d0418 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -940,8 +940,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=None, help="The folder path to the generation config. " "Defaults to None, will use the default generation config in vLLM. " - "If set to 'auto', the generation config will be automatically loaded " - "from model. If set to a folder path, the generation config " + "If set to 'auto', the generation config will be automatically " + "loaded from model. If set to a folder path, the generation config " "will be loaded from the specified folder path.") return parser diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8bdbc92431d42..fce4b5ac482cc 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -253,29 +253,10 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) def get_default_sampling_params(self) -> SamplingParams: - # Overwrite default sampling when generation config is set - overwrite_default_sampling_params = ( - self.llm_engine.model_config.generation_config is not None - ) - if overwrite_default_sampling_params: - generation_config_fields = self.llm_engine.generation_config_fields - available_params = [ - "repetition_penalty", - "temperature", - "top_k", - "top_p", - "min_p", - ] - default_param_dict = { - param: generation_config_fields.get(param, None) - for param in available_params - } - # Filter the None values - default_param_dict = { - k: v - for k, v in default_param_dict.items() if v is not None - } - return SamplingParams.from_optional(**default_param_dict) + diff_sampling_param = ( + self.llm_engine.model_config.get_diff_sampling_param()) + if diff_sampling_param: + return SamplingParams.from_optional(**diff_sampling_param) return SamplingParams() @overload diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a65c457f4ecb3..3a9ad5f4deee5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -570,35 +570,6 @@ def init_app_state( resolved_chat_template = load_chat_template(args.chat_template) logger.info("Using supplied chat template:\n%s", resolved_chat_template) - if args.generation_config: - generation_config_fields = model_config.try_get_generation_config() - available_params = [ - "repetition_penalty", - "temperature", - "top_k", - "top_p", - "min_p", - ] - - if any(p in generation_config_fields for p in available_params): - overwrite_config = { - p: generation_config_fields.get(p, None) - for p in available_params - } - logger.info("Overwriting generation config with: %s", - overwrite_config) - # Modify the ChatCompletionRequest to include the generation config - for k, v in overwrite_config.items(): - if v is not None: - ChatCompletionRequest.model_fields[k].default = v - CompletionRequest.model_fields[k].default = v - - # Rebuild the models to include the new fields - ChatCompletionRequest.model_rebuild(force=True) - CompletionRequest.model_rebuild(force=True) - else: - logger.warning("No generation config found.") - state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0738210e27cb6..b7396c42ed03a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -92,6 +92,7 @@ def __init__( "been registered") from e self.enable_prompt_tokens_details = enable_prompt_tokens_details + self._try_overwrite_sampling_param() async def create_chat_completion( self, @@ -856,3 +857,12 @@ def _should_check_for_unstreamed_tool_arg_tokens( and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None ) + + def _try_overwrite_sampling_param(self): + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info("Overwriting default chat sampling param with: %s", + diff_sampling_param) + for k, v in diff_sampling_param.items(): + ChatCompletionRequest.model_fields[k].default = v + ChatCompletionRequest.model_rebuild(force=True) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index ee97d35f2b087..0402a866da699 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -55,6 +55,7 @@ def __init__( prompt_adapters=prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) + self._try_overwrite_sampling_param() async def create_completion( self, @@ -541,3 +542,13 @@ def _create_completion_logprobs( tokens=out_tokens, top_logprobs=out_top_logprobs, ) + + def _try_overwrite_sampling_param(self): + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) + for k, v in diff_sampling_param.items(): + CompletionRequest.model_fields[k].default = v + CompletionRequest.model_rebuild(force=True) From ef436b7aef88d669d2f68ba71dc044598907d00b Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Fri, 13 Dec 2024 17:06:07 +0000 Subject: [PATCH 07/16] remove useless generation_config_fields Signed-off-by: Yanyi Liu --- vllm/v1/engine/llm_engine.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 90041c576f216..1b3a9f12d009e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -76,9 +76,6 @@ def __init__( asyncio_mode=False, ) - self.generation_config_fields = ( - self.model_config.try_get_generation_config()) - @classmethod def from_engine_args( cls, From cde9a97b9eaa509345ada21bddda2513e418fe4a Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Sat, 14 Dec 2024 03:20:44 +0000 Subject: [PATCH 08/16] fix test for chat Signed-off-by: Yanyi Liu --- tests/entrypoints/openai/test_serving_chat.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 93660e6118ca8..ffe8a8bd822b0 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -31,6 +31,9 @@ class MockModelConfig: multimodal_config = MultiModalConfig() hf_config = MockHFConfig() + def get_diff_sampling_param(self): + return {} + @dataclass class MockEngine: @@ -93,3 +96,38 @@ def test_serving_chat_should_set_correct_max_tokens(): asyncio.run(serving_chat.create_chat_completion(req)) assert mock_engine.generate.call_args.args[1].max_tokens == 10 + + +def test_serving_chat_could_load_correct_generation_config(): + + mock_model_config = MagicMock(spec=MockModelConfig) + mock_model_config.get_diff_sampling_param.return_value = { + "temperature": 0.5, + "repetition_penalty": 1.05 + } + + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + OpenAIServingChat(mock_engine, + mock_model_config, + BASE_MODEL_PATHS, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + lora_modules=None, + prompt_adapters=None, + request_logger=None) + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + guided_decoding_backend="outlines", + ) + + assert req.temperature == 0.5 + assert req.repetition_penalty == 1.05 From c0c806ea358a884ab6ea4e9ffb06ef67571f80ee Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Sun, 15 Dec 2024 04:45:38 +0000 Subject: [PATCH 09/16] remove tricky pydantic process Signed-off-by: Yanyi Liu --- tests/entrypoints/openai/test_serving_chat.py | 32 +++--- vllm/entrypoints/openai/protocol.py | 98 +++++++++++++------ vllm/entrypoints/openai/serving_chat.py | 21 ++-- vllm/entrypoints/openai/serving_completion.py | 23 ++--- 4 files changed, 107 insertions(+), 67 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index ffe8a8bd822b0..0a4fb2d39e21c 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -30,9 +30,10 @@ class MockModelConfig: tokenizer_revision = None multimodal_config = MultiModalConfig() hf_config = MockHFConfig() + diff_sampling_param = None def get_diff_sampling_param(self): - return {} + return self.diff_sampling_param or {} @dataclass @@ -100,8 +101,8 @@ def test_serving_chat_should_set_correct_max_tokens(): def test_serving_chat_could_load_correct_generation_config(): - mock_model_config = MagicMock(spec=MockModelConfig) - mock_model_config.get_diff_sampling_param.return_value = { + mock_model_config = MockModelConfig() + mock_model_config.diff_sampling_param = { "temperature": 0.5, "repetition_penalty": 1.05 } @@ -111,15 +112,15 @@ def test_serving_chat_could_load_correct_generation_config(): mock_engine.errored = False # Initialize the serving chat - OpenAIServingChat(mock_engine, - mock_model_config, - BASE_MODEL_PATHS, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - lora_modules=None, - prompt_adapters=None, - request_logger=None) + serving_chat = OpenAIServingChat(mock_engine, + mock_model_config, + BASE_MODEL_PATHS, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + lora_modules=None, + prompt_adapters=None, + request_logger=None) req = ChatCompletionRequest( model=MODEL_NAME, messages=[{ @@ -129,5 +130,8 @@ def test_serving_chat_could_load_correct_generation_config(): guided_decoding_backend="outlines", ) - assert req.temperature == 0.5 - assert req.repetition_penalty == 1.05 + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].temperature == 0.5 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ee94a9413f098..0f7e1909de839 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -170,8 +170,8 @@ class ChatCompletionRequest(OpenAIBaseModel): stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None - temperature: Optional[float] = 0.7 - top_p: Optional[float] = 1.0 + temperature: Optional[float] = None + top_p: Optional[float] = None tools: Optional[List[ChatCompletionToolsParam]] = None tool_choice: Optional[Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]] = "none" @@ -183,9 +183,9 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: begin-chat-completion-sampling-params best_of: Optional[int] = None use_beam_search: bool = False - top_k: int = -1 - min_p: float = 0.0 - repetition_penalty: float = 1.0 + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None length_penalty: float = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False @@ -296,15 +296,21 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_beam_search_params(self, - default_max_tokens: int) -> BeamSearchParams: + def to_beam_search_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None + ) -> BeamSearchParams: # TODO(#9845): remove max_tokens when field is removed from OpenAI API max_tokens = self.max_completion_tokens or self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} n = self.n if self.n is not None else 1 - temperature = self.temperature if self.temperature is not None else 0.0 + temperature = self.temperature or default_sampling_params.get( + "temperature", 0.0) return BeamSearchParams( beam_width=n, @@ -314,12 +320,27 @@ def to_beam_search_params(self, length_penalty=self.length_penalty, include_stop_str_in_output=self.include_stop_str_in_output) - def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: # TODO(#9845): remove max_tokens when field is removed from OpenAI API max_tokens = self.max_completion_tokens or self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + repetition_penalty = (self.repetition_penalty + or default_sampling_params.get( + "repetition_penalty", 1.0)) + temperature = self.temperature or default_sampling_params.get( + "temperature", 1.0) + top_p = self.top_p or default_sampling_params.get("top_p", 1.0) + top_k = self.top_k or default_sampling_params.get("top_k", -1) + min_p = self.min_p or default_sampling_params.get("min_p", 0.0) + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs @@ -349,11 +370,11 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: best_of=self.best_of, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - min_p=self.min_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, @@ -528,15 +549,15 @@ class CompletionRequest(OpenAIBaseModel): stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None - temperature: Optional[float] = 1.0 - top_p: Optional[float] = 1.0 + temperature: Optional[float] = None + top_p: Optional[float] = None user: Optional[str] = None # doc: begin-completion-sampling-params use_beam_search: bool = False - top_k: int = -1 - min_p: float = 0.0 - repetition_penalty: float = 1.0 + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None length_penalty: float = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False @@ -602,14 +623,20 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_beam_search_params(self, - default_max_tokens: int) -> BeamSearchParams: + def to_beam_search_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None + ) -> BeamSearchParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} n = self.n if self.n is not None else 1 - temperature = self.temperature if self.temperature is not None else 0.0 + temperature = self.temperature or default_sampling_params.get( + "temperature", 0.0) return BeamSearchParams( beam_width=n, @@ -619,11 +646,26 @@ def to_beam_search_params(self, length_penalty=self.length_penalty, include_stop_str_in_output=self.include_stop_str_in_output) - def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + repetition_penalty = (self.repetition_penalty + or default_sampling_params.get( + "repetition_penalty", 1.0)) + temperature = self.temperature or default_sampling_params.get( + "temperature", 1.0) + top_p = self.top_p or default_sampling_params.get("top_p", 1.0) + top_k = self.top_k or default_sampling_params.get("top_k", -1) + min_p = self.min_p or default_sampling_params.get("min_p", 0.0) + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.logprobs @@ -649,11 +691,11 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: best_of=self.best_of, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - min_p=self.min_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b7396c42ed03a..816aafac20731 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -92,7 +92,10 @@ def __init__( "been registered") from e self.enable_prompt_tokens_details = enable_prompt_tokens_details - self._try_overwrite_sampling_param() + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info("Overwriting default chat sampling param with: %s", + diff_sampling_param) async def create_chat_completion( self, @@ -191,12 +194,15 @@ async def create_chat_completion( sampling_params: Union[SamplingParams, BeamSearchParams] default_max_tokens = self.max_model_len - len( engine_prompt["prompt_token_ids"]) + # Build default sampling params + default_sampling_params = ( + self.model_config.get_diff_sampling_param()) if request.use_beam_search: sampling_params = request.to_beam_search_params( - default_max_tokens) + default_max_tokens, default_sampling_params) else: sampling_params = request.to_sampling_params( - default_max_tokens) + default_max_tokens, default_sampling_params) self._log_inputs(request_id, request_prompts[i], @@ -857,12 +863,3 @@ def _should_check_for_unstreamed_tool_arg_tokens( and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None ) - - def _try_overwrite_sampling_param(self): - diff_sampling_param = self.model_config.get_diff_sampling_param() - if diff_sampling_param: - logger.info("Overwriting default chat sampling param with: %s", - diff_sampling_param) - for k, v in diff_sampling_param.items(): - ChatCompletionRequest.model_fields[k].default = v - ChatCompletionRequest.model_rebuild(force=True) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 0402a866da699..dd1d1ea3d3d33 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -55,7 +55,11 @@ def __init__( prompt_adapters=prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) - self._try_overwrite_sampling_param() + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) async def create_completion( self, @@ -120,12 +124,15 @@ async def create_completion( sampling_params: Union[SamplingParams, BeamSearchParams] default_max_tokens = self.max_model_len - len( engine_prompt["prompt_token_ids"]) + # Build default sampling params + default_sampling_params = ( + self.model_config.get_diff_sampling_param()) if request.use_beam_search: sampling_params = request.to_beam_search_params( - default_max_tokens) + default_max_tokens, default_sampling_params) else: sampling_params = request.to_sampling_params( - default_max_tokens) + default_max_tokens, default_sampling_params) request_id_item = f"{request_id}-{i}" @@ -542,13 +549,3 @@ def _create_completion_logprobs( tokens=out_tokens, top_logprobs=out_top_logprobs, ) - - def _try_overwrite_sampling_param(self): - diff_sampling_param = self.model_config.get_diff_sampling_param() - if diff_sampling_param: - logger.info( - "Overwriting default completion sampling param with: %s", - diff_sampling_param) - for k, v in diff_sampling_param.items(): - CompletionRequest.model_fields[k].default = v - CompletionRequest.model_rebuild(force=True) From 41b1786a1bf48b6759dac183816b7f47217410c5 Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Sun, 15 Dec 2024 04:47:52 +0000 Subject: [PATCH 10/16] reverse Signed-off-by: Yanyi Liu --- vllm/entrypoints/openai/api_server.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3a9ad5f4deee5..2e27224b41864 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -352,8 +352,7 @@ async def show_version(): return JSONResponse(content=ver) -# Lazy include chat completion routes to make sure new sampling -# parameters are included +@router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): handler = chat(raw_request) @@ -373,8 +372,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -# Lazy include chat completion routes to make sure new sampling -# parameters are included +@router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) if handler is None: @@ -663,15 +661,6 @@ def signal_handler(*_) -> None: model_config = await engine_client.get_model_config() init_app_state(engine_client, model_config, app.state, args) - # Lazy include chat completion routes to make sure new sampling - # parameters are included - app.add_api_route("/v1/chat/completions", - create_chat_completion, - methods=["POST"]) - app.add_api_route("/v1/completions", - create_completion, - methods=["POST"]) - shutdown_task = await serve_http( app, host=args.host, From b24cc51b551cc46c5e3a570f9d73bbc42959442c Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Sun, 15 Dec 2024 04:50:31 +0000 Subject: [PATCH 11/16] fix default value Signed-off-by: Yanyi Liu --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0f7e1909de839..6c30681836914 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -336,7 +336,7 @@ def to_sampling_params( or default_sampling_params.get( "repetition_penalty", 1.0)) temperature = self.temperature or default_sampling_params.get( - "temperature", 1.0) + "temperature", 0.7) top_p = self.top_p or default_sampling_params.get("top_p", 1.0) top_k = self.top_k or default_sampling_params.get("top_k", -1) min_p = self.min_p or default_sampling_params.get("min_p", 0.0) From cffde045ccf13f0cd7fe298a23064c29ee2bc2e6 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Sun, 15 Dec 2024 21:14:01 +0800 Subject: [PATCH 12/16] fix wrong temp Signed-off-by: liuyanyi --- tests/entrypoints/openai/test_serving_chat.py | 12 +++++- vllm/config.py | 4 +- vllm/engine/arg_utils.py | 3 +- vllm/entrypoints/openai/protocol.py | 38 +++++++++++-------- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 71a061863774f..ae88cbf128ea3 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,6 +1,7 @@ import asyncio from contextlib import suppress from dataclasses import dataclass +from typing import Optional from unittest.mock import MagicMock from vllm.config import MultiModalConfig @@ -31,7 +32,7 @@ class MockModelConfig: multimodal_config = MultiModalConfig() hf_config = MockHFConfig() logits_processor_pattern = None - diff_sampling_param = None + diff_sampling_param: Optional[dict] = None def get_diff_sampling_param(self): return self.diff_sampling_param or {} @@ -136,3 +137,12 @@ def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].temperature == 0.5 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + # Test the param when user set it + req.temperature = 0.1 + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].temperature == 0.1 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 diff --git a/vllm/config.py b/vllm/config.py index 2742885a799aa..e90b66297f4e2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -835,8 +835,8 @@ def get_diff_sampling_param(self) -> Dict[str, Any]: ] if any(p in config for p in available_params): diff_sampling_param = { - p: config.get(p, None) - for p in available_params + p: config.get(p) + for p in available_params if config.get(p) is not None } else: diff_sampling_param = {} diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d346275dd4f47..933ecc335d290 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -996,7 +996,8 @@ def create_model_config(self) -> ModelConfig: mm_cache_preprocessor=self.mm_cache_preprocessor, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, - logits_processor_pattern=self.logits_processor_pattern) + logits_processor_pattern=self.logits_processor_pattern, + generation_config=self.generation_config) def create_load_config(self) -> LoadConfig: return LoadConfig( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 920d666709c56..0b34cae12388d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -385,14 +385,17 @@ def to_sampling_params( if default_sampling_params is None: default_sampling_params = {} # Default parameters - repetition_penalty = (self.repetition_penalty - or default_sampling_params.get( - "repetition_penalty", 1.0)) - temperature = self.temperature or default_sampling_params.get( - "temperature", 0.7) - top_p = self.top_p or default_sampling_params.get("top_p", 1.0) - top_k = self.top_k or default_sampling_params.get("top_k", -1) - min_p = self.min_p or default_sampling_params.get("min_p", 0.0) + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = (default_sampling_params.get( + "repetition_penalty", 1.0)) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 0.7) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get("top_p", 1.0) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get("top_k", -1) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get("min_p", 0.0) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: @@ -724,14 +727,17 @@ def to_sampling_params( if default_sampling_params is None: default_sampling_params = {} # Default parameters - repetition_penalty = (self.repetition_penalty - or default_sampling_params.get( - "repetition_penalty", 1.0)) - temperature = self.temperature or default_sampling_params.get( - "temperature", 1.0) - top_p = self.top_p or default_sampling_params.get("top_p", 1.0) - top_k = self.top_k or default_sampling_params.get("top_k", -1) - min_p = self.min_p or default_sampling_params.get("min_p", 0.0) + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = (default_sampling_params.get( + "repetition_penalty", 1.0)) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 1.0) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get("top_p", 1.0) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get("top_k", -1) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get("min_p", 0.0) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: From 08ad1b755047cb803ecaca5b47b0a2b5ebc08910 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Mon, 16 Dec 2024 11:10:35 +0800 Subject: [PATCH 13/16] Fix, add test for temp 0.0 Signed-off-by: liuyanyi --- tests/entrypoints/openai/test_serving_chat.py | 9 +++++++++ vllm/entrypoints/openai/protocol.py | 9 +++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index ae88cbf128ea3..51b255bb2a6db 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -146,3 +146,12 @@ def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].temperature == 0.1 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + # Test When temperature==0.0 + req.temperature = 0.0 + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].temperature == 0.0 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0b34cae12388d..6354fb6cb6895 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -361,8 +361,9 @@ def to_beam_search_params( if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 - temperature = self.temperature or default_sampling_params.get( - "temperature", 0.0) + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 0.0) return BeamSearchParams( beam_width=n, @@ -389,7 +390,7 @@ def to_sampling_params( repetition_penalty = (default_sampling_params.get( "repetition_penalty", 1.0)) if (temperature := self.temperature) is None: - temperature = default_sampling_params.get("temperature", 0.7) + temperature = default_sampling_params.get("temperature", 1.0) if (top_p := self.top_p) is None: top_p = default_sampling_params.get("top_p", 1.0) if (top_k := self.top_k) is None: @@ -705,7 +706,7 @@ def to_beam_search_params( default_sampling_params = {} n = self.n if self.n is not None else 1 temperature = self.temperature or default_sampling_params.get( - "temperature", 0.0) + "temperature", 1.0) return BeamSearchParams( beam_width=n, From 35205eb93b1df5747f1bba3b81961b75d51358ad Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Mon, 16 Dec 2024 11:12:27 +0800 Subject: [PATCH 14/16] fix beam Signed-off-by: liuyanyi --- vllm/entrypoints/openai/protocol.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6354fb6cb6895..f261c9e747c80 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -705,8 +705,9 @@ def to_beam_search_params( if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 - temperature = self.temperature or default_sampling_params.get( - "temperature", 1.0) + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 1.0) return BeamSearchParams( beam_width=n, From d9fdb3b4fc4354cd91b6277e081f05939f0b8647 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Mon, 16 Dec 2024 11:23:01 +0800 Subject: [PATCH 15/16] default 1.0 Signed-off-by: liuyanyi --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f261c9e747c80..2b75238bc4b30 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -363,7 +363,7 @@ def to_beam_search_params( n = self.n if self.n is not None else 1 if (temperature := self.temperature) is None: - temperature = default_sampling_params.get("temperature", 0.0) + temperature = default_sampling_params.get("temperature", 1.0) return BeamSearchParams( beam_width=n, From 69819cac96192520a1e4afa08b564de906de2afc Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Tue, 17 Dec 2024 11:19:50 +0800 Subject: [PATCH 16/16] [Enhancement] Add default sampling parameters for chat and completion requests Signed-off-by: liuyanyi --- tests/entrypoints/openai/test_chat.py | 1 + vllm/entrypoints/openai/protocol.py | 57 +++++++++++++++++++++------ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 8d23a2be6f9bb..5a3f4653a42ed 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -496,6 +496,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI, model=MODEL_NAME, messages=messages, max_completion_tokens=10, + temperature=0.0, # to ensure deterministic results extra_body=dict(guided_choice=sample_guided_choice, guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2b75238bc4b30..7bcde4ec8a1cf 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -348,6 +348,15 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params + # Default sampling parameters for chat completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + } + def to_beam_search_params( self, default_max_tokens: int, @@ -363,7 +372,8 @@ def to_beam_search_params( n = self.n if self.n is not None else 1 if (temperature := self.temperature) is None: - temperature = default_sampling_params.get("temperature", 1.0) + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) return BeamSearchParams( beam_width=n, @@ -387,16 +397,22 @@ def to_sampling_params( default_sampling_params = {} # Default parameters if (repetition_penalty := self.repetition_penalty) is None: - repetition_penalty = (default_sampling_params.get( - "repetition_penalty", 1.0)) + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) if (temperature := self.temperature) is None: - temperature = default_sampling_params.get("temperature", 1.0) + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) if (top_p := self.top_p) is None: - top_p = default_sampling_params.get("top_p", 1.0) + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) if (top_k := self.top_k) is None: - top_k = default_sampling_params.get("top_k", -1) + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) if (min_p := self.min_p) is None: - min_p = default_sampling_params.get("min_p", 0.0) + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: @@ -693,6 +709,15 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params + # Default sampling parameters for completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + } + def to_beam_search_params( self, default_max_tokens: int, @@ -730,16 +755,22 @@ def to_sampling_params( default_sampling_params = {} # Default parameters if (repetition_penalty := self.repetition_penalty) is None: - repetition_penalty = (default_sampling_params.get( - "repetition_penalty", 1.0)) + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) if (temperature := self.temperature) is None: - temperature = default_sampling_params.get("temperature", 1.0) + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) if (top_p := self.top_p) is None: - top_p = default_sampling_params.get("top_p", 1.0) + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) if (top_k := self.top_k) is None: - top_k = default_sampling_params.get("top_k", -1) + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) if (min_p := self.min_p) is None: - min_p = default_sampling_params.get("min_p", 0.0) + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: