Skip to content

Commit

Permalink
[Feature] Add load generation config from model (#11164)
Browse files Browse the repository at this point in the history
Signed-off-by: liuyanyi <[email protected]>
Signed-off-by: Yanyi Liu <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
  • Loading branch information
liuyanyi and DarkLight1337 authored Dec 19, 2024
1 parent 9835673 commit 5aef498
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 74 deletions.
30 changes: 30 additions & 0 deletions examples/offline_inference_with_default_generation_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from vllm import LLM

# Prompts to complete.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# `generation_config` defaults to None so that behavior stays consistent
# with previous versions. Passing "auto" instead asks the LLM to use the
# default generation config that comes from the model.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")

# Start from the default sampling parameters loaded from the model ...
sampling_params = llm.get_default_sampling_params()
# ... and override individual fields where needed.
sampling_params.temperature = 0.5

# `generate` returns a list of RequestOutput objects holding the prompt,
# the generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Show each prompt next to its first completion.
for request_output in outputs:
    prompt = request_output.prompt
    generated_text = request_output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
61 changes: 61 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock

from vllm.config import MultiModalConfig
Expand Down Expand Up @@ -31,6 +32,10 @@ class MockModelConfig:
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None

def get_diff_sampling_param(self):
    """Return the configured sampling overrides, or `{}` when unset/empty."""
    overrides = self.diff_sampling_param
    if overrides:
        return overrides
    return {}


@dataclass
Expand Down Expand Up @@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 10


def test_serving_chat_could_load_correct_generation_config():
    """Check that model-provided sampling defaults reach the engine.

    The mocked model config reports `temperature` and `repetition_penalty`
    overrides. The test verifies that these are applied when the request
    leaves them unset, and that a temperature set on the request
    (including 0.0) takes precedence.
    """
    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05
    }

    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    def submit_and_capture_sampling_params():
        # Any exception from the call is deliberately ignored; only the
        # arguments passed to the mocked `generate` are inspected.
        # NOTE(review): presumably the mocked engine cannot produce a real
        # response, hence the blanket suppress -- confirm.
        with suppress(Exception):
            asyncio.run(serving_chat.create_chat_completion(req))
        return mock_engine.generate.call_args.args[1]

    # Nothing set on the request: the model config values are used.
    sampling_params = submit_and_capture_sampling_params()
    assert sampling_params.temperature == 0.5
    assert sampling_params.repetition_penalty == 1.05

    # Test the param when user set it
    req.temperature = 0.1
    sampling_params = submit_and_capture_sampling_params()
    assert sampling_params.temperature == 0.1
    assert sampling_params.repetition_penalty == 1.05

    # Test When temperature==0.0
    req.temperature = 0.0
    sampling_params = submit_and_capture_sampling_params()
    assert sampling_params.temperature == 0.0
    assert sampling_params.repetition_penalty == 1.05
59 changes: 57 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
Expand Down Expand Up @@ -160,6 +161,7 @@ class ModelConfig:
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
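generation_config: The folder path to the generation config. Defaults to
None, which uses the default generation config in vLLM. If set to
"auto", the generation config is loaded from the model.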
"""

def compute_hash(self) -> str:
Expand Down Expand Up @@ -218,7 +220,8 @@ def __init__(self,
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None) -> None:
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
Expand Down Expand Up @@ -348,6 +351,8 @@ def __init__(self,
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern

self.generation_config = generation_config

self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
Expand Down Expand Up @@ -813,6 +818,56 @@ def get_multimodal_config(self) -> "MultiModalConfig":

return self.multimodal_config

def try_get_generation_config(self) -> Dict[str, Any]:
    """Load the generation config and return it as a diff dictionary.

    When `generation_config` is None or "auto", the config is looked up
    using `self.model` together with `self.revision`; otherwise
    `generation_config` itself is used as the lookup target (without a
    revision).

    Returns:
        Dict[str, Any]: The result of `to_diff_dict()` on the loaded
        config, or an empty dictionary if no config could be loaded.
    """
    load_kwargs: Dict[str, Any] = {
        "trust_remote_code": self.trust_remote_code,
    }
    use_model_source = (self.generation_config is None
                        or self.generation_config == "auto")
    if use_model_source:
        source = self.model
        load_kwargs["revision"] = self.revision
    else:
        source = self.generation_config

    # Inside the method body this name resolves to the module-level helper
    # imported from vllm.transformers_utils.config, not to this method.
    loaded = try_get_generation_config(source, **load_kwargs)

    return {} if loaded is None else loaded.to_diff_dict()

def get_diff_sampling_param(self) -> Dict[str, Any]:
    """Return the sampling parameters overridden by the generation config.

    Only a fixed set of sampling-related keys is considered
    (`repetition_penalty`, `temperature`, `top_k`, `top_p`, `min_p`);
    keys that are absent from the generation config or whose value is
    None are left out.

    Returns:
        Dict[str, Any]: The differing sampling parameters if
        `generation_config` is set, otherwise an empty dictionary.
    """
    if self.generation_config is None:
        # When generation_config is not set
        return {}

    config = self.try_get_generation_config()
    available_params = (
        "repetition_penalty",
        "temperature",
        "top_k",
        "top_p",
        "min_p",
    )
    # A single filtered comprehension is enough: it naturally produces an
    # empty dict when none of the keys are present, so no separate
    # "is any key present" check is needed.
    return {
        p: config[p]
        for p in available_params if config.get(p) is not None
    }

@property
def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
Expand Down
15 changes: 14 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ class EngineArgs:

kv_transfer_config: Optional[KVTransferConfig] = None

generation_config: Optional[str] = None

def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
Expand Down Expand Up @@ -942,6 +944,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default="auto",
help='The worker class to use for distributed execution.')

parser.add_argument(
"--generation-config",
type=nullable_str,
default=None,
help="The folder path to the generation config. "
"Defaults to None, will use the default generation config in vLLM. "
"If set to 'auto', the generation config will be automatically "
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path.")

return parser

@classmethod
Expand Down Expand Up @@ -985,7 +997,8 @@ def create_model_config(self) -> ModelConfig:
disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern)
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config)

def create_load_config(self) -> LoadConfig:
return LoadConfig(
Expand Down
23 changes: 4 additions & 19 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload

Expand Down Expand Up @@ -52,7 +52,6 @@
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
Expand All @@ -65,20 +64,6 @@
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a diff dictionary.

    The config is looked up from `model_config.model` at
    `model_config.revision`. An empty dictionary is returned when no
    config could be loaded.
    """
    loaded = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    return {} if loaded is None else loaded.to_diff_dict()


_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)

Expand Down Expand Up @@ -274,8 +259,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
self.model_config)
self.generation_config_fields = (
self.model_config.try_get_generation_config())

self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
Expand Down
9 changes: 8 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

def get_default_sampling_params(self) -> SamplingParams:
    """Build the default sampling parameters for this LLM.

    Returns:
        SamplingParams: Built via `SamplingParams.from_optional` from the
        model config's differing sampling parameters when there are any,
        otherwise a plain `SamplingParams()`.
    """
    model_config = self.llm_engine.model_config
    overrides = model_config.get_diff_sampling_param()
    if not overrides:
        return SamplingParams()
    return SamplingParams.from_optional(**overrides)

@overload
def generate(
self,
Expand Down Expand Up @@ -441,7 +448,7 @@ def generate(

if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
sampling_params = self.get_default_sampling_params()

self._validate_and_add_requests(
prompts=parsed_prompts,
Expand Down
Loading

0 comments on commit 5aef498

Please sign in to comment.