diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 5b0e76fe53685..de78d41ad12eb 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -59,6 +59,25 @@ def test_compilation_config(): assert args.compilation_config.level == 3 +def test_prefix_cache_default(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args([]) + + engine_args = EngineArgs.from_cli_args(args=args) + assert (not engine_args.enable_prefix_caching + ), "prefix caching defaults to off." + + # with flag to turn it on. + args = parser.parse_args(["--enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.enable_prefix_caching + + # with disable flag to turn it off. + args = parser.parse_args(["--no-enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert not engine_args.enable_prefix_caching + + def test_valid_pooling_config(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([ diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index 69cfdf5a395c1..ac5e7dde525a7 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -4,6 +4,7 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser if not envs.VLLM_USE_V1: pytest.skip( @@ -12,6 +13,24 @@ ) +def test_prefix_caching_from_cli(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args([]) + engine_args = EngineArgs.from_cli_args(args=args) + assert (engine_args.enable_prefix_caching + ), "V1 turns on prefix caching by default." + + # Turn it off possible with flag. + args = parser.parse_args(["--no-enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert not engine_args.enable_prefix_caching + + # Turn it on with flag. + args = parser.parse_args(["--enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.enable_prefix_caching + + def test_defaults(): engine_args = EngineArgs(model="facebook/opt-125m") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 90b4798f17a13..f0020562c3c3a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -416,9 +416,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'tokens. This is ignored on neuron devices and ' 'set to max-model-len') - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='Enables automatic prefix caching.') + parser.add_argument( + "--enable-prefix-caching", + action=argparse.BooleanOptionalAction, + default=EngineArgs.enable_prefix_caching, + help="Enables automatic prefix caching. " + "Use --no-enable-prefix-caching to disable explicitly.", + ) parser.add_argument('--disable-sliding-window', action='store_true', help='Disables sliding window, '