Skip to content

Commit

Permalink
[8/N] enable cli flag without a space (vllm-project#10529)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 21, 2024
1 parent e7a8341 commit 7560ae5
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 7 deletions.
4 changes: 2 additions & 2 deletions tests/compile/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_compile_correctness(test_setting: TestSetting):
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
]:
all_args.append(final_args + ["-O", str(level)])
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})

# inductor will change the output, so we only compare if the output
Expand All @@ -121,7 +121,7 @@ def test_compile_correctness(test_setting: TestSetting):
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
]:
all_args.append(final_args + ["-O", str(level)])
all_args.append(final_args + [f"-O{level}"])
all_envs.append({})
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
# "DYNAMO_ONCE" will always use fullgraph
Expand Down
28 changes: 28 additions & 0 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,34 @@ def test_limit_mm_per_prompt_parser(arg, expected):
assert args.limit_mm_per_prompt == expected


def test_compilation_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

# default value
args = parser.parse_args([])
assert args.compilation_config is None

# set to O3
args = parser.parse_args(["-O3"])
assert args.compilation_config.level == 3

# set to O 3 (space)
args = parser.parse_args(["-O", "3"])
assert args.compilation_config.level == 3

# set to O 3 (equals)
args = parser.parse_args(["-O=3"])
assert args.compilation_config.level == 3

# set to json
args = parser.parse_args(["--compilation-config", '{"level": 3}'])
assert args.compilation_config.level == 3

# set to json
args = parser.parse_args(['--compilation-config={"level": 3}'])
assert args.compilation_config.level == 3


def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
Expand Down
9 changes: 5 additions & 4 deletions tests/tpu/test_custom_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
def test_custom_dispatcher():
compare_two_settings(
"google/gemma-2b",
arg1=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_ONCE)],
arg2=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_AS_IS)],
arg1=[
"--enforce-eager",
f"-O{CompilationLevel.DYNAMO_ONCE}",
],
arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"],
env1={},
env2={})
5 changes: 4 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'testing only. level 3 is the recommended level '
'for production.\n'
'To specify the full compilation config, '
'use a JSON string.')
'use a JSON string.\n'
'Following the convention of traditional '
'compilers, using -O without space is also '
'supported. -O3 is equivalent to -O 3.')

return parser

Expand Down
4 changes: 4 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,10 @@ def parse_args(self, args=None, namespace=None):
else:
processed_args.append('--' +
arg[len('--'):].replace('_', '-'))
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
# allow -O flag to be used without space, e.g. -O3
processed_args.append('-O')
processed_args.append(arg[2:])
else:
processed_args.append(arg)

Expand Down

0 comments on commit 7560ae5

Please sign in to comment.