From dc26a3c567153a003c494d0961ddc9779e860f35 Mon Sep 17 00:00:00 2001
From: Zifei Tong
Date: Tue, 12 Nov 2024 04:02:29 +0900
Subject: [PATCH] Add a feature flag to turn prompt_tokens_details off by default

---
 tests/prefix_caching/test_prefix_caching.py | 3 ++-
 vllm/entrypoints/openai/api_server.py       | 1 +
 vllm/entrypoints/openai/cli_args.py         | 5 +++++
 vllm/entrypoints/openai/run_batch.py        | 6 ++++++
 vllm/entrypoints/openai/serving_chat.py     | 9 ++++++---
 5 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index 55d1ad90b4599..50723dbb610ac 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -72,7 +72,8 @@ def test_mixed_requests(
                 block_size) * block_size
         else:
             expected_num_cached_tokens = 0
-        assert req_outputs[i].num_cached_tokens == expected_num_cached_tokens
+        assert req_outputs[
+            i].num_cached_tokens == expected_num_cached_tokens
 
     vllm_outputs = [
         (output.prompt_token_ids + list(output.outputs[0].token_ids),
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 917b347ff1161..2302ab4752a4a 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -525,6 +525,7 @@ def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index a089985ac9758..7ace59d4beed9 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")
 
     return parser
 
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 0d016d949d22b..1b422a93263b2 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -78,6 +78,11 @@ def parse_args():
         help="Port number for the Prometheus metrics server "
         "(only needed if enable-metrics is set).",
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")
 
     return parser.parse_args()
 
@@ -217,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 4b644570e20e4..74867d8de8843 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -49,7 +49,8 @@ def __init__(self,
                  chat_template: Optional[str],
                  return_tokens_as_token_ids: bool = False,
                  enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None):
+                 tool_parser: Optional[str] = None,
+                 enable_prompt_tokens_details: bool = False):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -80,6 +81,8 @@ def __init__(self,
                                 f"tool_parser:'{tool_parser}' which has not "
                                 "been registered") from e
 
+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
@@ -536,7 +539,7 @@ async def chat_completion_stream_generator(
                                         completion_tokens=completion_tokens,
                                         total_tokens=num_prompt_tokens +
                                         completion_tokens)
-                if num_cached_tokens:
+                if self.enable_prompt_tokens_details and num_cached_tokens:
                     final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                         cached_tokens=num_cached_tokens)
 
@@ -710,7 +713,7 @@ async def chat_completion_full_generator(
                           completion_tokens=num_generated_tokens,
                           total_tokens=num_prompt_tokens +
                           num_generated_tokens)
-        if final_res.num_cached_tokens:
+        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)
 
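
Usage sketch (illustrative, not part of the patch itself): once an
OpenAI-compatible vLLM server is started with the new flag, for example
"vllm serve <model> --enable-prompt-tokens-details --enable-prefix-caching",
a chat completion's usage block should carry prompt_tokens_details whenever
cached prompt tokens are reported. The client snippet below is a minimal
sketch; the base_url, api_key, and model name are placeholders, not values
taken from this patch.

    # Minimal client-side check against a locally running vLLM OpenAI server.
    # Assumes the server was launched with --enable-prompt-tokens-details and
    # prefix caching enabled; otherwise usage.prompt_tokens_details stays unset.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    model = "placeholder-model-name"  # replace with the model the server serves

    messages = [{
        "role": "user",
        "content": "Explain prefix caching in one sentence."
    }]

    # Send the same prompt twice; the second request can hit the prefix cache,
    # in which case cached_tokens is reported in whole KV-cache blocks.
    for attempt in range(2):
        resp = client.chat.completions.create(model=model, messages=messages)
        usage = resp.usage
        details = getattr(usage, "prompt_tokens_details", None)
        cached = getattr(details, "cached_tokens", None) if details else None
        print(f"attempt {attempt}: prompt_tokens={usage.prompt_tokens}, "
              f"cached_tokens={cached}")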