[Feature] Add load generation config from model #11164

Merged · 20 commits · Dec 19, 2024
Changes from 16 commits
30 changes: 30 additions & 0 deletions examples/offline_inference_with_default_generation_config.py
@@ -0,0 +1,30 @@
from vllm import LLM

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create an LLM with built-in default generation config.
# The generation config is set to None by default to keep
# the behavior consistent with the previous version.
# If you want to use the default generation config from the model,
# you should set the generation_config to "auto".
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", generation_config="auto")

# Load the default sampling parameters from the model.
sampling_params = llm.get_default_sampling_params()
# Modify the sampling parameters if needed.
sampling_params.temperature = 0.5

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
61 changes: 61 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -1,6 +1,7 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from typing import Optional
from unittest.mock import MagicMock

from vllm.config import MultiModalConfig
@@ -31,6 +32,10 @@ class MockModelConfig:
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processor_pattern = None
    diff_sampling_param: Optional[dict] = None

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


@dataclass
@@ -94,3 +99,59 @@ def test_serving_chat_should_set_correct_max_tokens():
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10


def test_serving_chat_could_load_correct_generation_config():

    mock_model_config = MockModelConfig()
    mock_model_config.diff_sampling_param = {
        "temperature": 0.5,
        "repetition_penalty": 1.05
    }

    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    # Initialize the serving chat
    serving_chat = OpenAIServingChat(mock_engine,
                                     mock_model_config,
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     chat_template_content_format="auto",
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].temperature == 0.5
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

    # Test the param when user set it
    req.temperature = 0.1

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].temperature == 0.1
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05

    # Test when temperature == 0.0
    req.temperature = 0.0

    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].temperature == 0.0
    assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
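
For context, this is roughly how the behavior exercised by these tests looks from the client side, as a minimal sketch assuming an OpenAI-compatible vLLM server started with the new flag; the model name, port, and loaded default values are illustrative rather than taken from the diff.

from openai import OpenAI

# Assumes a server started with something like:
#   vllm serve Qwen/Qwen2.5-0.5B-Instruct --generation-config auto
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# No temperature supplied: the server falls back to the sampling defaults
# loaded from the model's generation config (temperature=0.5 in the test above).
resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "what is 1+1?"}],
)

# Explicit temperature: the user-supplied value overrides the loaded default,
# while untouched fields such as repetition_penalty keep the loaded values.
resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "what is 1+1?"}],
    temperature=0.1,
)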
59 changes: 57 additions & 2 deletions vllm/config.py
@@ -26,7 +26,8 @@
from vllm.transformers_utils.config import (
    ConfigFormat, get_config, get_hf_image_processor_config,
    get_hf_text_config, get_pooling_config,
    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
    get_sentence_transformer_tokenizer_config, is_encoder_decoder,
    try_get_generation_config, uses_mrope)
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
                        get_cpu_memory, print_warning_once, random_uuid,
                        resolve_obj_by_qualname)
@@ -160,6 +161,7 @@ class ModelConfig:
            logits processor qualified names that can be passed with the
            `logits_processors` extra completion argument. Defaults to None,
            which allows no processors.
        generation_config: Configuration parameter file for generation.
    """

    def __init__(self,
@@ -194,7 +196,8 @@ def __init__(self,
                 mm_cache_preprocessor: bool = False,
                 override_neuron_config: Optional[Dict[str, Any]] = None,
                 override_pooler_config: Optional["PoolerConfig"] = None,
                 logits_processor_pattern: Optional[str] = None) -> None:
                 logits_processor_pattern: Optional[str] = None,
                 generation_config: Optional[str] = None) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
@@ -322,6 +325,8 @@ def __init__(self,
        self.pooler_config = self._init_pooler_config(override_pooler_config)
        self.logits_processor_pattern = logits_processor_pattern

        self.generation_config = generation_config

        self._verify_quantization()
        self._verify_cuda_graph()
        self._verify_bnb_config()
@@ -787,6 +792,56 @@ def get_multimodal_config(self) -> "MultiModalConfig":

        return self.multimodal_config

    def try_get_generation_config(self) -> Dict[str, Any]:
        if self.generation_config is None or self.generation_config == "auto":
            config = try_get_generation_config(
                self.model,
                trust_remote_code=self.trust_remote_code,
                revision=self.revision,
            )
        else:
            config = try_get_generation_config(
                self.generation_config,
                trust_remote_code=self.trust_remote_code,
            )

        if config is None:
            return {}

        return config.to_diff_dict()

    def get_diff_sampling_param(self) -> Dict[str, Any]:
        """
        This method returns a dictionary containing the parameters
        that differ from the default sampling parameters, but only
        if `generation_config` is set. If `generation_config` is not
        set, an empty dictionary is returned.

        Returns:
            Dict[str, Any]: A dictionary with the differing sampling
            parameters if `generation_config` is set, otherwise an
            empty dictionary.
        """
        if self.generation_config is None:
            # When generation_config is not set
            return {}
        config = self.try_get_generation_config()
        available_params = [
            "repetition_penalty",
            "temperature",
            "top_k",
            "top_p",
            "min_p",
        ]
Review thread on lines +855 to +861:
    sampling_params.update_from_generation_config(
Contributor Author: I think token_ids has been used in llm_engine
        if any(p in config for p in available_params):
            diff_sampling_param = {
                p: config.get(p)
                for p in available_params if config.get(p) is not None
            }
        else:
            diff_sampling_param = {}
        return diff_sampling_param

    @property
    def is_encoder_decoder(self) -> bool:
        """Extract the HF encoder/decoder model flag."""
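To make the filtering above concrete, here is a small standalone sketch of what get_diff_sampling_param does with the dict loaded from the model's generation config; the dict contents are invented for illustration and do not come from any real model.

# Standalone illustration of the filtering in ModelConfig.get_diff_sampling_param.
# The generation_config dict is made up; in vLLM it would come from
# try_get_generation_config(...).to_diff_dict().
generation_config = {
    "temperature": 0.7,     # kept: listed in available_params
    "top_p": 0.8,           # kept: listed in available_params
    "max_new_tokens": 512,  # dropped: not one of the mapped sampling params
}

available_params = [
    "repetition_penalty",
    "temperature",
    "top_k",
    "top_p",
    "min_p",
]

diff_sampling_param = {
    p: generation_config.get(p)
    for p in available_params if generation_config.get(p) is not None
}
print(diff_sampling_param)  # {'temperature': 0.7, 'top_p': 0.8}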
15 changes: 14 additions & 1 deletion vllm/engine/arg_utils.py
@@ -199,6 +199,8 @@ class EngineArgs:

    kv_transfer_config: Optional[KVTransferConfig] = None

    generation_config: Optional[str] = None

    def __post_init__(self):
        if not self.tokenizer:
            self.tokenizer = self.model
@@ -941,6 +943,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            default="auto",
            help='The worker class to use for distributed execution.')

        parser.add_argument(
            "--generation-config",
            type=nullable_str,
            default=None,
            help="The folder path to the generation config. "
            "Defaults to None, will use the default generation config in vLLM. "
            "If set to 'auto', the generation config will be automatically "
            "loaded from model. If set to a folder path, the generation config "
            "will be loaded from the specified folder path.")

        return parser

    @classmethod
@@ -984,7 +996,8 @@ def create_model_config(self) -> ModelConfig:
            mm_cache_preprocessor=self.mm_cache_preprocessor,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern)
            logits_processor_pattern=self.logits_processor_pattern,
            generation_config=self.generation_config)

    def create_load_config(self) -> LoadConfig:
        return LoadConfig(
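The same option can also be set programmatically through EngineArgs. A sketch under the assumption of an example model name, showing the three accepted forms of generation_config:

from vllm.engine.arg_utils import EngineArgs

# generation_config=None          -> keep vLLM's built-in default sampling params
# generation_config="auto"        -> load the generation config shipped with the model
# generation_config="/some/path"  -> load generation_config.json from that folder
engine_args = EngineArgs(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # example model, not mandated by the PR
    generation_config="auto",
)

model_config = engine_args.create_model_config()
# Prints the sampling params the model's generation config overrides,
# e.g. {"temperature": ..., "top_p": ...} depending on the model.
print(model_config.get_diff_sampling_param())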
23 changes: 4 additions & 19 deletions vllm/engine/llm_engine.py
@@ -5,8 +5,8 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                    Iterable, List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
                    List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload

@@ -52,7 +52,6 @@
                           SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
@@ -65,20 +64,6 @@

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )

    if config is None:
        return {}

    return config.to_diff_dict()


_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)

@@ -274,8 +259,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            self.model_config)
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
9 changes: 8 additions & 1 deletion vllm/entrypoints/llm.py
@@ -258,6 +258,13 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        else:
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

    def get_default_sampling_params(self) -> SamplingParams:
        diff_sampling_param = (
            self.llm_engine.model_config.get_diff_sampling_param())
        if diff_sampling_param:
            return SamplingParams.from_optional(**diff_sampling_param)
        return SamplingParams()

    @overload
    def generate(
        self,
@@ -441,7 +448,7 @@ def generate(

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
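Finally, a small sketch of what get_default_sampling_params builds from that dict; the values below are placeholders rather than anything a particular model ships.

from vllm import SamplingParams

# Stand-in for llm_engine.model_config.get_diff_sampling_param().
diff_sampling_param = {"temperature": 0.7, "repetition_penalty": 1.05}

params = SamplingParams.from_optional(**diff_sampling_param)
print(params.temperature, params.repetition_penalty)  # 0.7 1.05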