diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index fd6564bbfe630..50723dbb610ac 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -27,6 +27,7 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("cached_position", [0, 1])
+@pytest.mark.parametrize("block_size", [16])
 def test_mixed_requests(
     hf_runner,
     vllm_runner,
@@ -36,11 +37,12 @@ def test_mixed_requests(
     dtype: str,
     max_tokens: int,
     cached_position: int,
+    block_size: int,
     monkeypatch,
 ) -> None:
     """
     Test the case when some sequences have the prefix cache hit
-    and the others don't. The cached position determines where 
+    and the others don't. The cached position determines where
     the sequence is at among the batch of prefills.
     """
     override_backend_env_variable(monkeypatch, backend)
@@ -53,12 +55,30 @@ def test_mixed_requests(
             model,
             dtype=dtype,
             enable_prefix_caching=True,
+            block_size=block_size,
     ) as vllm_model:
         # Run the first prompt so the cache is populated
         vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
 
         # Run all the promopts
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
+        req_outputs = vllm_model.model.generate(example_prompts, greedy_params)
+
+        # Verify number of cached tokens
+        for i in range(len(req_outputs)):
+            if i == cached_position:
+                expected_num_cached_tokens = (
+                    len(req_outputs[i].prompt_token_ids) //
+                    block_size) * block_size
+            else:
+                expected_num_cached_tokens = 0
+            assert req_outputs[
+                i].num_cached_tokens == expected_num_cached_tokens
+
+        vllm_outputs = [
+            (output.prompt_token_ids + list(output.outputs[0].token_ids),
+             output.prompt + output.outputs[0].text) for output in req_outputs
+        ]
 
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 3e4070a25cf90..6a24cdbc6a18f 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -540,6 +540,7 @@ def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 74ea41344bece..eb08a89293370 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")
 
     return parser
 
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 0e0bb66c057df..820aefd8800d9 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)
 
 
+class PromptTokenUsageInfo(OpenAIBaseModel):
+    cached_tokens: Optional[int] = None
+
+
 class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
+    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
 
 
 class RequestResponseMetadata(BaseModel):
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 0d016d949d22b..1b422a93263b2 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -78,6 +78,11 @@ def parse_args():
         help="Port number for the Prometheus metrics server "
         "(only needed if enable-metrics is set).",
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")
 
     return parser.parse_args()
 
@@ -217,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 9551b4f2091dd..74867d8de8843 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -18,8 +18,8 @@
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
-    ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
+    RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing,
@@ -49,7 +49,8 @@ def __init__(self,
                  chat_template: Optional[str],
                  return_tokens_as_token_ids: bool = False,
                  enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None):
+                 tool_parser: Optional[str] = None,
+                 enable_prompt_tokens_details: bool = False):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -80,6 +81,8 @@ def __init__(self,
                     f"tool_parser:'{tool_parser}' which has not "
                     "been registered") from e
 
+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
@@ -252,6 +255,7 @@ async def chat_completion_stream_generator(
         previous_num_tokens = [0] * num_choices
         finish_reason_sent = [False] * num_choices
         num_prompt_tokens = 0
+        num_cached_tokens = None
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -305,6 +309,7 @@ async def chat_completion_stream_generator(
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
                 if first_iteration:
+                    num_cached_tokens = res.num_cached_tokens
                     # Send first response for each request.n (index) with
                     # the role
                     role = self.get_chat_request_role(request)
@@ -530,11 +535,13 @@ async def chat_completion_stream_generator(
             # is sent, send the usage
             if include_usage:
                 completion_tokens = sum(previous_num_tokens)
-                final_usage = UsageInfo(
-                    prompt_tokens=num_prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=num_prompt_tokens + completion_tokens,
-                )
+                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                                        completion_tokens=completion_tokens,
+                                        total_tokens=num_prompt_tokens +
+                                        completion_tokens)
+                if self.enable_prompt_tokens_details and num_cached_tokens:
+                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
+                        cached_tokens=num_cached_tokens)
 
                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=request_id,
@@ -702,11 +709,13 @@ async def chat_completion_full_generator(
             num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
         num_generated_tokens = sum(
             len(output.token_ids) for output in final_res.outputs)
-        usage = UsageInfo(
-            prompt_tokens=num_prompt_tokens,
-            completion_tokens=num_generated_tokens,
-            total_tokens=num_prompt_tokens + num_generated_tokens,
-        )
+        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                          completion_tokens=num_generated_tokens,
+                          total_tokens=num_prompt_tokens +
+                          num_generated_tokens)
+        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+            usage.prompt_tokens_details = PromptTokenUsageInfo(
+                cached_tokens=final_res.num_cached_tokens)
 
         request_metadata.final_usage_info = usage
 
diff --git a/vllm/outputs.py b/vllm/outputs.py
index abfdb7d328126..badf50d0602d6 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -83,10 +83,11 @@ class RequestOutput:
         finished: Whether the whole request is finished.
         metrics: Metrics associated with the request.
         lora_request: The LoRA request that was used to generate the output.
-        encoder_prompt: The encoder prompt string of the request;
-                        None if decoder-only
-        encoder_prompt_token_ids: The token IDs of the encoder prompt;
-                        None if decoder-only
+        encoder_prompt: The encoder prompt string of the request.
+                        None if decoder-only.
+        encoder_prompt_token_ids: The token IDs of the encoder prompt.
+                        None if decoder-only.
+        num_cached_tokens: The number of tokens with prefix cache hit.
     """
 
     def __init__(
@@ -101,6 +102,7 @@ def __init__(
         lora_request: Optional[LoRARequest] = None,
        	encoder_prompt: Optional[str] = None,
        	encoder_prompt_token_ids: Optional[List[int]] = None,
+        num_cached_tokens: Optional[int] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -112,6 +114,7 @@ def __init__(
         self.lora_request = lora_request
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
+        self.num_cached_tokens = num_cached_tokens
 
     @classmethod
     def new(
@@ -192,6 +195,8 @@ def from_seq_group(
 
         outputs = []
         include_prompt = True
+        # num_cached_tokens should be the same for all the sequences
+        num_cached_tokens = None
         for i, seq in enumerate(top_n_seqs):
             output_text = seq.get_output_text_to_return(
                 text_buffer_length, delta)
@@ -199,6 +204,7 @@ def from_seq_group(
             output_token_ids = seq.get_output_token_ids_to_return(delta)
             num_output_tokens = 1 if isinstance(output_token_ids,
                                                 int) else len(output_token_ids)
+            num_cached_tokens = seq.data.get_num_cached_tokens()
 
             output_logprobs = seq.output_logprobs if include_logprobs else None
@@ -272,7 +278,7 @@ def from_seq_group(
         init_args = (seq_group.request_id, prompt, prompt_token_ids,
                      prompt_logprobs, outputs, finished, seq_group.metrics,
                      seq_group.lora_request, encoder_prompt,
-                     encoder_prompt_token_ids)
+                     encoder_prompt_token_ids, num_cached_tokens)
 
         if use_cache:
             request_output = seq_group.cached_request_output
@@ -293,7 +299,8 @@ def __repr__(self) -> str:
                 f"outputs={self.outputs}, "
                 f"finished={self.finished}, "
                 f"metrics={self.metrics}, "
-                f"lora_request={self.lora_request})")
+                f"lora_request={self.lora_request}, "
+                f"num_cached_tokens={self.num_cached_tokens})")
 
 
 class EmbeddingRequestOutput:
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 7d7ddc7ec4447..1370cb5c4f9d2 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
                        ...] = msgspec.field(default_factory=tuple)
     # The number of tokens that are computed (that run against the model).
     _num_computed_tokens: int = 0
+    # The number of tokens with prefix cache hit.
+    _num_cached_tokens: int = 0
     _stage: SequenceStage = SequenceStage.PREFILL
     _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
 
@@ -323,6 +325,14 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int):
         if self.get_num_uncomputed_tokens() == 0:
             self._stage = SequenceStage.DECODE
 
+    def get_num_cached_tokens(self) -> int:
+        """Return the number of tokens with prefix cache hit."""
+        return self._num_cached_tokens
+
+    def update_num_cached_tokens(self, num_cached_tokens: int):
+        """Update the number of tokens with prefix cache hit."""
+        self._num_cached_tokens = num_cached_tokens
+
     def reset_state_for_recompute(self) -> None:
         """Reset the number of computed tokens from this sequence. It is
         supposed to be called when a sequence needs to be started from
@@ -379,7 +389,7 @@ def __repr__(self) -> str:
 
 class Sequence:
     """Stores the data, status, and block information of a sequence.
-    
+
     The sequence is constructed from the :data:`DecoderOnlyInputs` (for
     decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
     instance passed in through the :code:`inputs` constructor argument.
@@ -906,7 +916,7 @@ class SequenceGroupMetadata(
         multi_modal_data: Multi modal data.
         mm_processor_kwargs: Multimodal input processor / mapper overrides.
         encoder_seq_data: Optional sequence data for encoder prompt
-                          (SequenceGroup.encoder_seq). Should be None 
+                          (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
         cross_block_table: Optional cross-attention block table associated
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index e1446192ce3d6..2da02f21f8342 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -542,6 +542,9 @@ def _compute_for_prefix_cache_hit(
         # this may be larger than the sequence length if chunked
         # prefill is enabled.
         prefix_cache_len = len(computed_block_nums) * self.block_size
+        seq_group_metadata.seq_data[inter_data.seq_ids[
+            seq_idx]].update_num_cached_tokens(prefix_cache_len)
+
         # The number of so far computed prompt tokens in this sequence.
         context_len = inter_data.context_lens[seq_idx]
         # The total number of prompt tokens in this sequence.
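Usage sketch (illustrative only, not part of the patch): after this change, the offline API exposes the cached-token count on each RequestOutput, and the OpenAI-compatible server (and run_batch) reports it as usage.prompt_tokens_details.cached_tokens when started with --enable-prompt-tokens-details. The model name, block size, and prompts below are placeholders.

    # Minimal sketch of reading num_cached_tokens; model/prompts are placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True, block_size=16)
    params = SamplingParams(temperature=0.0, max_tokens=5)

    shared_prefix = "You are a helpful assistant. " * 20
    prompts = [shared_prefix + "Summarize prefix caching.",
               shared_prefix + "Explain paged attention."]

    # Warm the cache with the first prompt, then generate with both prompts.
    llm.generate(prompts[:1], params)
    for out in llm.generate(prompts, params):
        # num_cached_tokens counts whole cached blocks, so it is a multiple of
        # block_size (see the expected value computed in the test above).
        print(out.request_id, out.num_cached_tokens)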