From 3f8864801aa81a935df2014107e80cbb73bb14c3 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Sat, 9 Nov 2024 06:16:28 +0900 Subject: [PATCH 1/7] WIP Signed-off-by: Zifei Tong --- vllm/entrypoints/openai/protocol.py | 5 +++++ vllm/outputs.py | 13 ++++++++++--- vllm/sequence.py | 13 +++++++++++-- vllm/worker/model_runner.py | 4 +++- 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 1335e51bd152c..88b5216f984c2 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel): data: List[ModelCard] = Field(default_factory=list) +class PromptTokenUsageInfo(OpenAIBaseModel): + cached_tokens: Optional[int] = None + + class UsageInfo(OpenAIBaseModel): prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens: Optional[int] = 0 + prompt_tokens_details: Optional[PromptTokenUsageInfo] = None class RequestResponseMetadata(BaseModel): diff --git a/vllm/outputs.py b/vllm/outputs.py index 951976310e7ae..e76e44573c4db 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -83,10 +83,11 @@ class RequestOutput: finished: Whether the whole request is finished. metrics: Metrics associated with the request. lora_request: The LoRA request that was used to generate the output. - encoder_prompt: The encoder prompt string of the request; + encoder_prompt: The encoder prompt string of the request; None if decoder-only encoder_prompt_token_ids: The token IDs of the encoder prompt; None if decoder-only + num_cached_tokens: The number of tokens with prefix cache hit. """ def __init__( @@ -101,6 +102,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, encoder_prompt: Optional[str] = None, encoder_prompt_token_ids: Optional[List[int]] = None, + num_cached_tokens: Optional[int] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -112,6 +114,7 @@ def __init__( self.lora_request = lora_request self.encoder_prompt = encoder_prompt self.encoder_prompt_token_ids = encoder_prompt_token_ids + self.num_cached_tokens = num_cached_tokens @classmethod def from_seq_group( @@ -162,6 +165,8 @@ def from_seq_group( outputs = [] include_prompt = True + # num_cached_tokens should be the same for all the sequences + num_cached_tokens = None for i, seq in enumerate(top_n_seqs): output_text = seq.get_output_text_to_return( text_buffer_length, delta) @@ -169,6 +174,7 @@ def from_seq_group( output_token_ids = seq.get_output_token_ids_to_return(delta) num_output_tokens = 1 if isinstance(output_token_ids, int) else len(output_token_ids) + num_cached_tokens = seq.data.get_num_cached_tokens() output_logprobs = seq.output_logprobs if include_logprobs else None @@ -242,7 +248,7 @@ def from_seq_group( init_args = (seq_group.request_id, prompt, prompt_token_ids, prompt_logprobs, outputs, finished, seq_group.metrics, seq_group.lora_request, encoder_prompt, - encoder_prompt_token_ids) + encoder_prompt_token_ids, num_cached_tokens) if use_cache: request_output = seq_group.cached_request_output @@ -263,7 +269,8 @@ def __repr__(self) -> str: f"outputs={self.outputs}, " f"finished={self.finished}, " f"metrics={self.metrics}, " - f"lora_request={self.lora_request})") + f"lora_request={self.lora_request}, " + f"num_cached_tokens={self.num_cached_tokens})") class EmbeddingRequestOutput: diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d7ddc7ec4447..4b697ebf2a687 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -167,6 +167,7 @@ class 
SequenceData(msgspec.Struct, ...] = msgspec.field(default_factory=tuple) # The number of tokens that are computed (that run against the model). _num_computed_tokens: int = 0 + _num_cached_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) @@ -323,6 +324,14 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): if self.get_num_uncomputed_tokens() == 0: self._stage = SequenceStage.DECODE + def get_num_cached_tokens(self) -> int: + """Return the number of tokens with prefix cache hit.""" + return self._num_cached_tokens + + def update_num_cached_tokens(self, num_cached_tokens: int): + """Update the number of tokens with prefix cache hit.""" + self._num_cached_tokens = num_cached_tokens + def reset_state_for_recompute(self) -> None: """Reset the number of computed tokens from this sequence. It is supposed to be called when a sequence needs to be started from @@ -379,7 +388,7 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - + The sequence is constructed from the :data:`DecoderOnlyInputs` (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder) instance passed in through the :code:`inputs` constructor argument. @@ -906,7 +915,7 @@ class SequenceGroupMetadata( multi_modal_data: Multi modal data. mm_processor_kwargs: Multimodal input processor / mapper overrides. encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None + (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. cross_block_table: Optional cross-attention block table associated diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a1ec2e85be7b8..de8041f97a347 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -533,7 +533,6 @@ def _compute_for_prefix_cache_hit( and self.sliding_window is None and inter_data.is_prompt) inter_data.prefix_cache_hit = prefix_cache_hit - if not prefix_cache_hit: return @@ -542,6 +541,9 @@ def _compute_for_prefix_cache_hit( # this may be larger than the sequence length if chunked # prefill is enabled. prefix_cache_len = len(computed_block_nums) * self.block_size + seq_group_metadata.seq_data[inter_data.seq_ids[ + seq_idx]].update_num_cached_tokens(prefix_cache_len) + # The number of so far computed prompt tokens in this sequence. context_len = inter_data.context_lens[seq_idx] # The total number of prompt tokens in this sequence. 
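Note on PATCH 1/7: the cached-token count is derived in _compute_for_prefix_cache_hit as len(computed_block_nums) * self.block_size, then propagated via SequenceData.update_num_cached_tokens() up to RequestOutput.num_cached_tokens. Because prefix-cache reuse happens at KV-cache block granularity, the reported value is always rounded down to a whole number of blocks. A minimal sketch of that rounding, using illustrative values (a block size of 16 matches what the test added in PATCH 3/7 parametrizes, and the prompt length here is made up):

    # Illustrative values only; the rounding mirrors
    # len(computed_block_nums) * block_size from the patch.
    block_size = 16    # KV-cache block size (assumed for illustration)
    prompt_len = 70    # prompt length in tokens (assumed for illustration)

    # Only fully populated blocks count as cached on a prefix-cache hit.
    num_cached_tokens = (prompt_len // block_size) * block_size
    assert num_cached_tokens == 64  # the trailing 6 tokens are recomputed

The test in PATCH 3/7 checks exactly this quantity: (len(prompt_token_ids) // block_size) * block_size for the request that hits the cache, and 0 for the others.
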
From a001d6b86fb69b9825cf8214b74a087a22a0dbff Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Sat, 9 Nov 2024 07:51:40 +0900 Subject: [PATCH 2/7] savE Signed-off-by: Zifei Tong --- vllm/entrypoints/openai/serving_chat.py | 10 ++++++++-- vllm/outputs.py | 8 ++++---- vllm/worker/model_runner.py | 1 + 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9551b4f2091dd..d786fb6b718c0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -18,8 +18,8 @@ ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, - ToolCall, UsageInfo) + DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, + RequestResponseMetadata, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, OpenAIServing, @@ -252,6 +252,7 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices num_prompt_tokens = 0 + num_cached_tokens = None if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -305,6 +306,7 @@ async def chat_completion_stream_generator( # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). if first_iteration: + num_cached_tokens = res.num_cached_tokens # Send first response for each request.n (index) with # the role role = self.get_chat_request_role(request) @@ -534,6 +536,8 @@ async def chat_completion_stream_generator( prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, total_tokens=num_prompt_tokens + completion_tokens, + prompt_tokens_details=PromptTokenUsageInfo( + cached_tokens=num_cached_tokens), ) final_usage_chunk = ChatCompletionStreamResponse( @@ -706,6 +710,8 @@ async def chat_completion_full_generator( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, + prompt_tokens_details=PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens), ) request_metadata.final_usage_info = usage diff --git a/vllm/outputs.py b/vllm/outputs.py index e76e44573c4db..7b1585503059f 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -83,10 +83,10 @@ class RequestOutput: finished: Whether the whole request is finished. metrics: Metrics associated with the request. lora_request: The LoRA request that was used to generate the output. - encoder_prompt: The encoder prompt string of the request; - None if decoder-only - encoder_prompt_token_ids: The token IDs of the encoder prompt; - None if decoder-only + encoder_prompt: The encoder prompt string of the request. + None if decoder-only. + encoder_prompt_token_ids: The token IDs of the encoder prompt. + None if decoder-only. num_cached_tokens: The number of tokens with prefix cache hit. 
""" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index de8041f97a347..f6d5f2115e9c1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -533,6 +533,7 @@ def _compute_for_prefix_cache_hit( and self.sliding_window is None and inter_data.is_prompt) inter_data.prefix_cache_hit = prefix_cache_hit + if not prefix_cache_hit: return From 373be32fdfe2834daff2c83b78617b9394f7430f Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Sat, 9 Nov 2024 09:00:55 +0900 Subject: [PATCH 3/7] Add tests Signed-off-by: Zifei Tong --- tests/prefix_caching/test_prefix_caching.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index fd6564bbfe630..ef6a18d719457 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -27,6 +27,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("cached_position", [0, 1]) +@pytest.mark.parametrize("block_size", [16]) def test_mixed_requests( hf_runner, vllm_runner, @@ -36,11 +37,12 @@ def test_mixed_requests( dtype: str, max_tokens: int, cached_position: int, + block_size: int, monkeypatch, ) -> None: """ Test the case when some sequences have the prefix cache hit - and the others don't. The cached position determines where + and the others don't. The cached position determines where the sequence is at among the batch of prefills. """ override_backend_env_variable(monkeypatch, backend) @@ -53,12 +55,26 @@ def test_mixed_requests( model, dtype=dtype, enable_prefix_caching=True, + block_size=block_size, ) as vllm_model: # Run the first prompt so the cache is populated vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) # Run all the promopts - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + req_outputs = vllm_model.model.generate(example_prompts, greedy_params) + + # Verify number of cached tokens + expected_num_cached_tokens = ( + len(req_outputs[cached_position].prompt_token_ids) // + block_size) * block_size + assert req_outputs[ + cached_position].num_cached_tokens == expected_num_cached_tokens + + vllm_outputs = [ + (output.prompt_token_ids + list(output.outputs[0].token_ids), + output.prompt + output.outputs[0].text) for output in req_outputs + ] check_outputs_equal( outputs_0_lst=hf_outputs, From 593e88f9d6b601ebc33c722b022bd55eb08a4f4c Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 12 Nov 2024 02:51:52 +0900 Subject: [PATCH 4/7] don't add prompt_tokens_details if no token is cached Signed-off-by: Zifei Tong --- vllm/entrypoints/openai/serving_chat.py | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d786fb6b718c0..4b644570e20e4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -532,13 +532,13 @@ async def chat_completion_stream_generator( # is sent, send the usage if include_usage: completion_tokens = sum(previous_num_tokens) - final_usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens, - prompt_tokens_details=PromptTokenUsageInfo( - cached_tokens=num_cached_tokens), - ) + final_usage = 
UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + if num_cached_tokens: + final_usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens) final_usage_chunk = ChatCompletionStreamResponse( id=request_id, @@ -706,13 +706,13 @@ async def chat_completion_full_generator( num_prompt_tokens += len(final_res.encoder_prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - prompt_tokens_details=PromptTokenUsageInfo( - cached_tokens=final_res.num_cached_tokens), - ) + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + if final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) request_metadata.final_usage_info = usage From cc5f5d5d691a57fa8d9cbe359a3562d8d139e00c Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 12 Nov 2024 03:46:44 +0900 Subject: [PATCH 5/7] update tests Signed-off-by: Zifei Tong --- tests/prefix_caching/test_prefix_caching.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index ef6a18d719457..55d1ad90b4599 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -65,11 +65,14 @@ def test_mixed_requests( req_outputs = vllm_model.model.generate(example_prompts, greedy_params) # Verify number of cached tokens - expected_num_cached_tokens = ( - len(req_outputs[cached_position].prompt_token_ids) // - block_size) * block_size - assert req_outputs[ - cached_position].num_cached_tokens == expected_num_cached_tokens + for i in range(len(req_outputs)): + if i == cached_position: + expected_num_cached_tokens = ( + len(req_outputs[i].prompt_token_ids) // + block_size) * block_size + else: + expected_num_cached_tokens = 0 + assert req_outputs[i].num_cached_tokens == expected_num_cached_tokens vllm_outputs = [ (output.prompt_token_ids + list(output.outputs[0].token_ids), From d76adf0693dc58974b7eddc1902b1465523bc495 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 12 Nov 2024 04:02:29 +0900 Subject: [PATCH 6/7] Add feature flags to turn it off by default Signed-off-by: Zifei Tong --- tests/prefix_caching/test_prefix_caching.py | 3 ++- vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/cli_args.py | 5 +++++ vllm/entrypoints/openai/run_batch.py | 6 ++++++ vllm/entrypoints/openai/serving_chat.py | 9 ++++++--- 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 55d1ad90b4599..50723dbb610ac 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -72,7 +72,8 @@ def test_mixed_requests( block_size) * block_size else: expected_num_cached_tokens = 0 - assert req_outputs[i].num_cached_tokens == expected_num_cached_tokens + assert req_outputs[ + i].num_cached_tokens == expected_num_cached_tokens vllm_outputs = [ (output.prompt_token_ids + list(output.outputs[0].token_ids), diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py 
index 917b347ff1161..2302ab4752a4a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -525,6 +525,7 @@ def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, ) if model_config.task == "generate" else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a089985ac9758..7ace59d4beed9 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") return parser diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 0d016d949d22b..1b422a93263b2 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -78,6 +78,11 @@ def parse_args(): help="Port number for the Prometheus metrics server " "(only needed if enable-metrics is set).", ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") return parser.parse_args() @@ -217,6 +222,7 @@ async def main(args): prompt_adapters=None, request_logger=request_logger, chat_template=None, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, ) if model_config.task == "generate" else None openai_serving_embedding = OpenAIServingEmbedding( engine, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 4b644570e20e4..74867d8de8843 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -49,7 +49,8 @@ def __init__(self, chat_template: Optional[str], return_tokens_as_token_ids: bool = False, enable_auto_tools: bool = False, - tool_parser: Optional[str] = None): + tool_parser: Optional[str] = None, + enable_prompt_tokens_details: bool = False): super().__init__(engine_client=engine_client, model_config=model_config, base_model_paths=base_model_paths, @@ -80,6 +81,8 @@ def __init__(self, f"tool_parser:'{tool_parser}' which has not " "been registered") from e + self.enable_prompt_tokens_details = enable_prompt_tokens_details + async def create_chat_completion( self, request: ChatCompletionRequest, @@ -536,7 +539,7 @@ async def chat_completion_stream_generator( completion_tokens=completion_tokens, total_tokens=num_prompt_tokens + completion_tokens) - if num_cached_tokens: + if self.enable_prompt_tokens_details and num_cached_tokens: final_usage.prompt_tokens_details = PromptTokenUsageInfo( cached_tokens=num_cached_tokens) @@ -710,7 +713,7 @@ async def chat_completion_full_generator( completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens) - if final_res.num_cached_tokens: + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: usage.prompt_tokens_details = PromptTokenUsageInfo( cached_tokens=final_res.num_cached_tokens) From d02fee31c9451d0a821b456cc62ef56313517564 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 12 
Nov 2024 04:20:44 +0900 Subject: [PATCH 7/7] address comment Signed-off-by: Zifei Tong --- vllm/sequence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/sequence.py b/vllm/sequence.py index 4b697ebf2a687..1370cb5c4f9d2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -167,6 +167,7 @@ class SequenceData(msgspec.Struct, ...] = msgspec.field(default_factory=tuple) # The number of tokens that are computed (that run against the model). _num_computed_tokens: int = 0 + # The number of tokens with prefix cache hit. _num_cached_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
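
Taken together, the series surfaces the prefix-cache hit count in two places: offline callers can read RequestOutput.num_cached_tokens, and OpenAI-compatible clients receive usage.prompt_tokens_details.cached_tokens when the server is launched with --enable-prompt-tokens-details and prefix caching is enabled. The sketch below shows how a client might observe the new field; the server invocation in the comment, the local URL, and the placeholder model name are assumptions for illustration, not part of this series.

    # Client-side sketch, assuming a server started roughly like:
    #   vllm serve <model> --enable-prefix-caching --enable-prompt-tokens-details
    # and reachable at http://localhost:8000. <model> is a placeholder.
    import requests

    url = "http://localhost:8000/v1/chat/completions"
    payload = {
        "model": "<model>",  # placeholder model name
        "messages": [{
            "role": "user",
            "content": "Summarize the benefits of prefix caching. " * 10,
        }],
        "max_tokens": 16,
    }

    # First call populates the prefix cache; prompt_tokens_details should be
    # absent (or null) because no prompt tokens were cached yet.
    requests.post(url, json=payload).json()

    # A second identical call reuses the cached prompt blocks.
    usage = requests.post(url, json=payload).json()["usage"]
    details = usage.get("prompt_tokens_details")
    if details:  # only set when the flag is on and some tokens were cached
        print("cached prompt tokens:", details["cached_tokens"])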