Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] allow candidate compile sizes #10984

Merged
merged 5 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def test_compilation_config():
args = parser.parse_args(["-O=3"])
assert args.compilation_config.level == 3

# set to json
args = parser.parse_args(["--compilation-config", '{"level": 3}'])
# set to string form of a dict
args = parser.parse_args(["--compilation-config", "{'level': 3}"])
assert args.compilation_config.level == 3

# set to json
args = parser.parse_args(['--compilation-config={"level": 3}'])
# set to string form of a dict
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously this was a JSON string, but I find JSON too restrictive: e.g. strings must use double quotes, trailing commas are not allowed, etc.

Therefore, we switch here to the string form of a Python dict.

args = parser.parse_args(["--compilation-config={'level': 3}"])
assert args.compilation_config.level == 3


Expand Down
44 changes: 22 additions & 22 deletions vllm/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import copy
import enum
import hashlib
Expand Down Expand Up @@ -2191,14 +2192,10 @@ class CompilationConfig(BaseModel):
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
is compiled. In addition, compile for cudagraph sizes that are
in candidate_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- candidate_compile_sizes: sizes to compile for inductor.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
Expand Down Expand Up @@ -2227,8 +2224,7 @@ class CompilationConfig(BaseModel):
])

use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
inductor_compile_sizes: Optional[List[int]] = Field(default=None)
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)

Expand Down Expand Up @@ -2294,7 +2290,9 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
return CompilationConfig.model_validate_json(cli_value)
# do not use `eval`, it is dangerous and can execute arbitrary code
dict_value = ast.literal_eval(cli_value)
return CompilationConfig.model_validate(dict_value)

def model_post_init(self, __context: Any) -> None:

Expand Down Expand Up @@ -2355,18 +2353,20 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes)
if self.inductor_specialize_for_cudagraph_no_more_than is not None:
assert self.inductor_compile_sizes is None, (
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None")
self.compile_sizes = [
x for x in self.capture_sizes
if x <= self.inductor_specialize_for_cudagraph_no_more_than
]
else:
if self.inductor_compile_sizes is None:
self.inductor_compile_sizes = []
self.compile_sizes = self.inductor_compile_sizes

if self.candidate_compile_sizes is None:
self.candidate_compile_sizes = []
self.compile_sizes = [
x for x in self.candidate_compile_sizes if x in self.capture_sizes
]
ignored_sizes = [
x for x in self.candidate_compile_sizes
if x not in self.capture_sizes
]
if ignored_sizes:
logger.warning(("candidate_compile_sizes %s are ignored "
"because they are not cudagraph capture sizes."),
ignored_sizes)

# sort to make sure cudagraph capture sizes are in descending order
self.capture_sizes.sort(reverse=True)
Expand Down
5 changes: 1 addition & 4 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,9 @@ def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
# CompilationConfig object
if isinstance(self.compilation_config, (int)):
if isinstance(self.compilation_config, (int, dict)):
self.compilation_config = CompilationConfig.from_cli(
str(self.compilation_config))
elif isinstance(self.compilation_config, (dict)):
self.compilation_config = CompilationConfig.from_cli(
json.dumps(self.compilation_config))

# Setup plugins
from vllm.plugins import load_general_plugins
Expand Down
6 changes: 1 addition & 5 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
import json
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Expand Down Expand Up @@ -186,12 +185,9 @@ def __init__(
kwargs["disable_log_stats"] = True

if compilation_config is not None:
if isinstance(compilation_config, (int)):
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config))
elif isinstance(compilation_config, (dict)):
compilation_config_instance = CompilationConfig.from_cli(
json.dumps(compilation_config))
else:
compilation_config_instance = compilation_config
else:
Expand Down
Loading