
Commit

Add feature flags to turn it off by default
zifeitong committed Nov 11, 2024
1 parent 853b770 commit dc26a3c
Showing 5 changed files with 20 additions and 4 deletions.
tests/prefix_caching/test_prefix_caching.py (2 additions, 1 deletion)
@@ -72,7 +72,8 @@ def test_mixed_requests(
                                           block_size) * block_size
         else:
             expected_num_cached_tokens = 0
-        assert req_outputs[i].num_cached_tokens == expected_num_cached_tokens
+        assert req_outputs[
+            i].num_cached_tokens == expected_num_cached_tokens

     vllm_outputs = [
         (output.prompt_token_ids + list(output.outputs[0].token_ids),
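For context on the assertion being re-wrapped above: with prefix caching, cached prompt tokens are counted in whole blocks, so the test expects the prompt length rounded down to a multiple of block_size. A small illustrative sketch (the helper name and the numbers are made up, not part of the test):

```python
# Illustrative only: cached prompt tokens come in whole KV-cache blocks, so
# the expected count is the prompt length rounded down to a block_size multiple.
def expected_cached_tokens(num_prompt_tokens: int, block_size: int) -> int:
    return (num_prompt_tokens // block_size) * block_size

print(expected_cached_tokens(37, 16))  # 32 -> two full blocks cached
print(expected_cached_tokens(15, 16))  # 0  -> less than one full block
```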
vllm/entrypoints/openai/api_server.py (1 addition)
@@ -525,6 +525,7 @@ def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
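With the flag threaded through init_app_state, a server started with --enable-prompt-tokens-details can report cached prompt tokens in the usage block of chat responses. A hypothetical client-side check follows; the base_url, api_key, and model name are placeholders, not part of this commit, and prompt_tokens_details may be absent on older client versions:

```python
# Hypothetical check against a vLLM OpenAI-compatible server assumed to be
# running with --enable-prompt-tokens-details; connection details are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
)
# With the flag enabled and a prefix-cache hit, the server populates
# usage.prompt_tokens_details.cached_tokens; otherwise it stays unset.
print(resp.usage)
print(getattr(resp.usage, "prompt_tokens_details", None))
```

Without the flag, the usage block keeps its previous shape and the details field is simply never populated.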
vllm/entrypoints/openai/cli_args.py (5 additions)
@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")

     return parser

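Because the option is a store_true switch defaulting to False, prompt_tokens_details stays off unless the flag is passed explicitly. A minimal sketch of that behavior using stock argparse rather than vLLM's FlexibleArgumentParser:

```python
# Standalone sketch with stock argparse (not vLLM's FlexibleArgumentParser),
# showing that the feature is disabled by default and enabled only by the flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-prompt-tokens-details",
    action='store_true',
    default=False,
    help="If set to True, enable prompt_tokens_details in usage.")

print(parser.parse_args([]).enable_prompt_tokens_details)  # False
print(parser.parse_args(
    ["--enable-prompt-tokens-details"]).enable_prompt_tokens_details)  # True
```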
vllm/entrypoints/openai/run_batch.py (6 additions)
@@ -78,6 +78,11 @@ def parse_args():
         help="Port number for the Prometheus metrics server "
         "(only needed if enable-metrics is set).",
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")

     return parser.parse_args()

@@ -217,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
vllm/entrypoints/openai/serving_chat.py (6 additions, 3 deletions)
@@ -49,7 +49,8 @@ def __init__(self,
                  chat_template: Optional[str],
                  return_tokens_as_token_ids: bool = False,
                  enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None):
+                 tool_parser: Optional[str] = None,
+                 enable_prompt_tokens_details: bool = False):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -80,6 +81,8 @@ def __init__(self,
                                 f"tool_parser:'{tool_parser}' which has not "
                                 "been registered") from e

+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
@@ -536,7 +539,7 @@ async def chat_completion_stream_generator(
                                         completion_tokens=completion_tokens,
                                         total_tokens=num_prompt_tokens +
                                         completion_tokens)
-                if num_cached_tokens:
+                if self.enable_prompt_tokens_details and num_cached_tokens:
                     final_usage.prompt_tokens_details = PromptTokenUsageInfo(
                         cached_tokens=num_cached_tokens)

@@ -710,7 +713,7 @@ async def chat_completion_full_generator(
                           completion_tokens=num_generated_tokens,
                           total_tokens=num_prompt_tokens +
                           num_generated_tokens)
-        if final_res.num_cached_tokens:
+        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
             usage.prompt_tokens_details = PromptTokenUsageInfo(
                 cached_tokens=final_res.num_cached_tokens)

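Both response paths now gate the cached-token details behind the new flag: prompt_tokens_details is attached only when the flag is enabled and the engine reported a nonzero cached-token count. A self-contained sketch of that gating, using simplified stand-ins for the UsageInfo and PromptTokenUsageInfo protocol models:

```python
# Simplified stand-ins for the vLLM protocol models, only to illustrate the
# gating added in this commit.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PromptTokenUsageInfo:
    cached_tokens: int


@dataclass
class UsageInfo:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None


def build_usage(num_prompt_tokens: int, num_generated_tokens: int,
                num_cached_tokens: Optional[int],
                enable_prompt_tokens_details: bool) -> UsageInfo:
    usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                      completion_tokens=num_generated_tokens,
                      total_tokens=num_prompt_tokens + num_generated_tokens)
    # Details are reported only when the feature flag is on AND the engine
    # actually counted cached prompt tokens (mirrors the diff above).
    if enable_prompt_tokens_details and num_cached_tokens:
        usage.prompt_tokens_details = PromptTokenUsageInfo(
            cached_tokens=num_cached_tokens)
    return usage


print(build_usage(100, 20, 64, enable_prompt_tokens_details=False))  # no details
print(build_usage(100, 20, 64, enable_prompt_tokens_details=True))   # details set
```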
