From 5aef49806da2e6cc8a92c948d44e8a722469135f Mon Sep 17 00:00:00 2001 From: Yanyi Liu Date: Thu, 19 Dec 2024 18:50:38 +0800 Subject: [PATCH] [Feature] Add load generation config from model (#11164) Signed-off-by: liuyanyi Signed-off-by: Yanyi Liu Signed-off-by: DarkLight1337 Co-authored-by: Cyrus Leung --- ...nference_with_default_generation_config.py | 30 ++++ tests/entrypoints/openai/test_serving_chat.py | 61 ++++++++ vllm/config.py | 59 +++++++- vllm/engine/arg_utils.py | 15 +- vllm/engine/llm_engine.py | 23 +-- vllm/entrypoints/llm.py | 9 +- vllm/entrypoints/openai/protocol.py | 139 ++++++++++++++---- vllm/entrypoints/openai/serving_chat.py | 12 +- vllm/entrypoints/openai/serving_completion.py | 13 +- vllm/v1/engine/processor.py | 20 +-- 10 files changed, 307 insertions(+), 74 deletions(-) create mode 100644 examples/offline_inference_with_default_generation_config.py diff --git a/examples/offline_inference_with_default_generation_config.py b/examples/offline_inference_with_default_generation_config.py new file mode 100644 index 0000000000000..346bb80b1e23f --- /dev/null +++ b/examples/offline_inference_with_default_generation_config.py @@ -0,0 +1,30 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM with built-in default generation config. +# The generation config is set to None by default to keep +# the behavior consistent with the previous version. +# If you want to use the default generation config from the model, +# you should set the generation_config to "auto". +llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto") + +# Load the default sampling parameters from the model. +sampling_params = llm.get_default_sampling_params() +# Modify the sampling parameters if needed. +sampling_params.temperature = 0.5 + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. 
+for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 5b40a04db15ee..51b255bb2a6db 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,6 +1,7 @@ import asyncio from contextlib import suppress from dataclasses import dataclass +from typing import Optional from unittest.mock import MagicMock from vllm.config import MultiModalConfig @@ -31,6 +32,10 @@ class MockModelConfig: multimodal_config = MultiModalConfig() hf_config = MockHFConfig() logits_processor_pattern = None + diff_sampling_param: Optional[dict] = None + + def get_diff_sampling_param(self): + return self.diff_sampling_param or {} @dataclass @@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens(): asyncio.run(serving_chat.create_chat_completion(req)) assert mock_engine.generate.call_args.args[1].max_tokens == 10 + + +def test_serving_chat_could_load_correct_generation_config(): + + mock_model_config = MockModelConfig() + mock_model_config.diff_sampling_param = { + "temperature": 0.5, + "repetition_penalty": 1.05 + } + + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + serving_chat = OpenAIServingChat(mock_engine, + mock_model_config, + BASE_MODEL_PATHS, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + lora_modules=None, + prompt_adapters=None, + request_logger=None) + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + guided_decoding_backend="outlines", + ) + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].temperature == 0.5 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + # Test the param when user set it + req.temperature = 0.1 + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].temperature == 0.1 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + # Test When temperature==0.0 + req.temperature = 0.0 + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].temperature == 0.0 + assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 diff --git a/vllm/config.py b/vllm/config.py index 9acc3efa4816c..0e886e18fcd6d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,7 +27,8 @@ from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, - get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) + get_sentence_transformer_tokenizer_config, is_encoder_decoder, + try_get_generation_config, uses_mrope) from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, get_cpu_memory, print_warning_once, random_uuid, resolve_obj_by_qualname) @@ -160,6 +161,7 @@ class ModelConfig: logits processor qualified names that can be passed with the `logits_processors` extra completion argument. Defaults to None, which allows no processors. 
+ generation_config: Configuration parameter file for generation. """ def compute_hash(self) -> str: @@ -218,7 +220,8 @@ def __init__(self, disable_mm_preprocessor_cache: bool = False, override_neuron_config: Optional[Dict[str, Any]] = None, override_pooler_config: Optional["PoolerConfig"] = None, - logits_processor_pattern: Optional[str] = None) -> None: + logits_processor_pattern: Optional[str] = None, + generation_config: Optional[str] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -348,6 +351,8 @@ def __init__(self, self.pooler_config = self._init_pooler_config(override_pooler_config) self.logits_processor_pattern = logits_processor_pattern + self.generation_config = generation_config + self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -813,6 +818,56 @@ def get_multimodal_config(self) -> "MultiModalConfig": return self.multimodal_config + def try_get_generation_config(self) -> Dict[str, Any]: + if self.generation_config is None or self.generation_config == "auto": + config = try_get_generation_config( + self.model, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + ) + else: + config = try_get_generation_config( + self.generation_config, + trust_remote_code=self.trust_remote_code, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + def get_diff_sampling_param(self) -> Dict[str, Any]: + """ + This method returns a dictionary containing the parameters + that differ from the default sampling parameters, but only + if `generation_config` is set. If `generation_config` is not + set, an empty dictionary is returned. + + Returns: + Dict[str, Any]: A dictionary with the differing sampling + parameters if `generation_config` is set, otherwise an + empty dictionary. + """ + if self.generation_config is None: + # When generation_config is not set + return {} + config = self.try_get_generation_config() + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + ] + if any(p in config for p in available_params): + diff_sampling_param = { + p: config.get(p) + for p in available_params if config.get(p) is not None + } + else: + diff_sampling_param = {} + return diff_sampling_param + @property def is_encoder_decoder(self) -> bool: """Extract the HF encoder/decoder model flag.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 75e79d509d2e1..912a8b2f54adb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -197,6 +197,8 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None + generation_config: Optional[str] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -942,6 +944,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="auto", help='The worker class to use for distributed execution.') + parser.add_argument( + "--generation-config", + type=nullable_str, + default=None, + help="The folder path to the generation config. " + "Defaults to None, will use the default generation config in vLLM. " + "If set to 'auto', the generation config will be automatically " + "loaded from model. 
If set to a folder path, the generation config " + "will be loaded from the specified folder path.") + return parser @classmethod @@ -985,7 +997,8 @@ def create_model_config(self) -> ModelConfig: disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, override_neuron_config=self.override_neuron_config, override_pooler_config=self.override_pooler_config, - logits_processor_pattern=self.logits_processor_pattern) + logits_processor_pattern=self.logits_processor_pattern, + generation_config=self.generation_config) def create_load_config(self) -> LoadConfig: return LoadConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dc2d77d6927cd..e78b6f4d26758 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,8 +5,8 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, - Iterable, List, Mapping, NamedTuple, Optional) +from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, + List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence from typing import Set, Type, Union, cast, overload @@ -52,7 +52,6 @@ SequenceGroupOutput, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) -from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( @@ -65,20 +64,6 @@ logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} - - return config.to_diff_dict() - - _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput) @@ -274,8 +259,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: return tokenizer_group.get_lora_tokenizer(sequence.lora_request) self.seq_counter = Counter() - self.generation_config_fields = _load_generation_config_dict( - self.model_config) + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 58ab892676b9a..94d4a4d89adc9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -258,6 +258,13 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: else: tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + def get_default_sampling_params(self) -> SamplingParams: + diff_sampling_param = ( + self.llm_engine.model_config.get_diff_sampling_param()) + if diff_sampling_param: + return SamplingParams.from_optional(**diff_sampling_param) + return SamplingParams() + @overload def generate( self, @@ -441,7 +448,7 @@ def generate( if sampling_params is None: # Use default sampling params. 
- sampling_params = SamplingParams() + sampling_params = self.get_default_sampling_params() self._validate_and_add_requests( prompts=parsed_prompts, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 5a70e0952666b..1314de714215e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -211,8 +211,8 @@ class ChatCompletionRequest(OpenAIBaseModel): stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None - temperature: Optional[float] = 1.0 - top_p: Optional[float] = 1.0 + temperature: Optional[float] = None + top_p: Optional[float] = None tools: Optional[List[ChatCompletionToolsParam]] = None tool_choice: Optional[Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]] = "none" @@ -224,9 +224,9 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: begin-chat-completion-sampling-params best_of: Optional[int] = None use_beam_search: bool = False - top_k: int = -1 - min_p: float = 0.0 - repetition_penalty: float = 1.0 + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None length_penalty: float = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False @@ -348,15 +348,32 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_beam_search_params(self, - default_max_tokens: int) -> BeamSearchParams: + # Default sampling parameters for chat completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + } + + def to_beam_search_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None + ) -> BeamSearchParams: # TODO(#9845): remove max_tokens when field is removed from OpenAI API max_tokens = self.max_completion_tokens or self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} n = self.n if self.n is not None else 1 - temperature = self.temperature if self.temperature is not None else 0.0 + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) return BeamSearchParams( beam_width=n, @@ -367,13 +384,36 @@ def to_beam_search_params(self, include_stop_str_in_output=self.include_stop_str_in_output) def to_sampling_params( - self, default_max_tokens: int, - logits_processor_pattern: Optional[str]) -> SamplingParams: + self, + default_max_tokens: int, + logits_processor_pattern: Optional[str], + default_sampling_params: Optional[dict] = None) -> SamplingParams: # TODO(#9845): remove max_tokens when field is removed from OpenAI API max_tokens = self.max_completion_tokens or self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", 
self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs @@ -403,11 +443,11 @@ def to_sampling_params( best_of=self.best_of, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - min_p=self.min_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, @@ -584,15 +624,15 @@ class CompletionRequest(OpenAIBaseModel): stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None - temperature: Optional[float] = 1.0 - top_p: Optional[float] = 1.0 + temperature: Optional[float] = None + top_p: Optional[float] = None user: Optional[str] = None # doc: begin-completion-sampling-params use_beam_search: bool = False - top_k: int = -1 - min_p: float = 0.0 - repetition_penalty: float = 1.0 + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None length_penalty: float = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False @@ -669,14 +709,30 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_beam_search_params(self, - default_max_tokens: int) -> BeamSearchParams: + # Default sampling parameters for completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": -1, + "min_p": 0.0, + } + + def to_beam_search_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None + ) -> BeamSearchParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} n = self.n if self.n is not None else 1 - temperature = self.temperature if self.temperature is not None else 0.0 + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 1.0) return BeamSearchParams( beam_width=n, @@ -687,12 +743,35 @@ def to_beam_search_params(self, include_stop_str_in_output=self.include_stop_str_in_output) def to_sampling_params( - self, default_max_tokens: int, - logits_processor_pattern: Optional[str]) -> SamplingParams: + self, + default_max_tokens: int, + logits_processor_pattern: Optional[str], + default_sampling_params: Optional[dict] = None) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", 
self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.logprobs @@ -718,11 +797,11 @@ def to_sampling_params( best_of=self.best_of, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - min_p=self.min_p, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 81bce0dd370bb..d085333563d19 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -91,6 +91,10 @@ def __init__( "been registered") from e self.enable_prompt_tokens_details = enable_prompt_tokens_details + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info("Overwriting default chat sampling param with: %s", + diff_sampling_param) async def create_chat_completion( self, @@ -191,13 +195,17 @@ async def create_chat_completion( sampling_params: Union[SamplingParams, BeamSearchParams] default_max_tokens = self.max_model_len - len( engine_prompt["prompt_token_ids"]) + # Build default sampling params + default_sampling_params = ( + self.model_config.get_diff_sampling_param()) if request.use_beam_search: sampling_params = request.to_beam_search_params( - default_max_tokens) + default_max_tokens, default_sampling_params) else: sampling_params = request.to_sampling_params( default_max_tokens, - self.model_config.logits_processor_pattern) + self.model_config.logits_processor_pattern, + default_sampling_params) self._log_inputs(request_id, request_prompts[i], diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 5cf9df92e296e..aaad7b8c7f44c 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -55,6 +55,11 @@ def __init__( prompt_adapters=prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) + diff_sampling_param = self.model_config.get_diff_sampling_param() + if diff_sampling_param: + logger.info( + "Overwriting default completion sampling param with: %s", + diff_sampling_param) async def create_completion( self, @@ -118,13 +123,17 @@ async def create_completion( sampling_params: Union[SamplingParams, BeamSearchParams] default_max_tokens = self.max_model_len - len( engine_prompt["prompt_token_ids"]) + # Build default sampling params + default_sampling_params = ( + self.model_config.get_diff_sampling_param()) if request.use_beam_search: sampling_params = request.to_beam_search_params( - default_max_tokens) + default_max_tokens, default_sampling_params) else: sampling_params = request.to_sampling_params( default_max_tokens, - self.model_config.logits_processor_pattern) + self.model_config.logits_processor_pattern, + default_sampling_params) request_id_item = f"{request_id}-{i}" diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 
61dce40a584c8..ffcaa158d252d 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, Mapping, Optional, Tuple, Union +from typing import Mapping, Optional, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, @@ -12,7 +12,6 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient @@ -34,8 +33,8 @@ def __init__( self.lora_config = lora_config self.tokenizer = tokenizer - self.generation_config_fields = _load_generation_config_dict( - model_config) + self.generation_config_fields = model_config.try_get_generation_config( + ) self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer, mm_registry) @@ -181,16 +180,3 @@ def _validate_model_inputs(self, inputs: ProcessorInputs): # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens - - -def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: - config = try_get_generation_config( - model_config.model, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.revision, - ) - - if config is None: - return {} - - return config.to_diff_dict()
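
Beyond the offline example added in this patch, the same fallback behavior applies to the OpenAI-compatible server once it is started with the new flag. The sketch below is illustrative only, not part of the patch: it assumes a server launched with "vllm serve Qwen/Qwen2.5-0.5B-Instruct --generation-config auto" on the default port 8000 and the openai Python client installed; the model name and port are placeholders. Sampling fields left unset in the request are filled from the model's generation_config, while values supplied explicitly in the request still take precedence.

    # Assumption: `vllm serve Qwen/Qwen2.5-0.5B-Instruct --generation-config auto`
    # is running locally on port 8000 (model name and port are illustrative).
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # temperature / top_p / top_k are omitted here, so the server falls back to
    # the values loaded from the model's generation_config rather than the
    # built-in OpenAI-style defaults.
    chat = client.chat.completions.create(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        messages=[{"role": "user", "content": "what is 1+1?"}],
    )
    print(chat.choices[0].message.content)

    # A value set explicitly in the request still overrides the loaded defaults.
    chat = client.chat.completions.create(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        messages=[{"role": "user", "content": "what is 1+1?"}],
        temperature=0.0,
    )
    print(chat.choices[0].message.content)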