From 1cd01d4dede46be7c65fdd9e4dcb6eb83c57154f Mon Sep 17 00:00:00 2001
From: chaunceyjiang
Date: Wed, 14 May 2025 15:12:34 +0000
Subject: [PATCH] [Feature][v1]: cached_tokens in Chat Completion Response usage

Signed-off-by: chaunceyjiang
Co-authored-by: simon-mo
---
 tests/v1/core/test_scheduler_e2e.py | 11 ++++++++++-
 vllm/v1/core/sched/scheduler.py     |  5 ++++-
 vllm/v1/engine/__init__.py          |  3 +++
 vllm/v1/engine/output_processor.py  |  9 ++++++---
 vllm/v1/request.py                  |  4 ++++
 5 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py
index 0a79424a30b7..511d57d405ba 100644
--- a/tests/v1/core/test_scheduler_e2e.py
+++ b/tests/v1/core/test_scheduler_e2e.py
@@ -19,7 +19,8 @@ def model() -> LLM:
                enable_prefix_caching=True,
                long_prefill_token_threshold=2,
                max_num_batched_tokens=6,
-               max_num_seqs=3)
+               max_num_seqs=3,
+               block_size=16)
 
 
 def test_concurrent_partial_prefill(model):
@@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model):
     assert len(outputs) == 3
     for output in outputs:
         assert len(output.outputs) == 1
+
+
+def test_prefix_cache_stats_is_recorded(model):
+    # 17 tokens will make sure first 16 tokens are cached in a block
+    input_tokens = {"prompt_token_ids": [101] * 17}
+    _ = model.generate([input_tokens])
+    outputs = model.generate([input_tokens])
+    assert outputs[0].num_cached_tokens == 16
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 2152409019b9..c873ced343bf 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -457,7 +457,9 @@ def schedule(self) -> SchedulerOutput:
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
-
+            # Count the number of prefix cached tokens.
+            if request.num_cached_tokens < 0:
+                request.num_cached_tokens = num_computed_tokens
             # Encoder-related.
             if encoder_inputs_to_schedule:
                 scheduled_encoder_inputs[request.request_id] = (
@@ -798,6 +800,7 @@ def update_from_output(
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
+                        num_cached_tokens=request.num_cached_tokens,
                     ))
 
             else:
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 122a5a72cc36..41db99beaad5 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -107,6 +107,9 @@ class EngineCoreOutput(
     events: Optional[list[EngineCoreEvent]] = None
     kv_transfer_params: Optional[dict[str, Any]] = None
 
+    # The number of tokens with prefix cache hits.
+    num_cached_tokens: int = 0
+
     @property
     def finished(self) -> bool:
         return self.finish_reason is not None
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index a7a9b0e4a161..293c291b4341 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -147,6 +147,7 @@ def make_request_output(
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
         kv_transfer_params: Optional[dict[str, Any]] = None,
+        num_cached_tokens: int = 0,
     ) -> Optional[RequestOutput]:
 
         finished = finish_reason is not None
@@ -169,7 +170,7 @@ def make_request_output(
             return None
 
         return self._new_request_output(request_id, outputs, finished,
-                                        kv_transfer_params)
+                                        kv_transfer_params, num_cached_tokens)
 
     def _new_request_output(
         self,
@@ -177,6 +178,7 @@ def _new_request_output(
         outputs: list[CompletionOutput],
         finished: bool,
         kv_transfer_params: Optional[dict[str, Any]] = None,
+        num_cached_tokens: int = 0,
     ) -> RequestOutput:
 
         if self.output_kind == RequestOutputKind.DELTA:
@@ -193,6 +195,7 @@ def _new_request_output(
             outputs=outputs,
             finished=finished,
             kv_transfer_params=kv_transfer_params,
+            num_cached_tokens=num_cached_tokens,
         )
 
     def _new_completion_output(
@@ -340,7 +343,7 @@ def process_outputs(
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
             kv_transfer_params = engine_core_output.kv_transfer_params
-
+            num_cached_tokens = engine_core_output.num_cached_tokens
             req_state.is_prefilling = False
 
             # 2) Detokenize the token ids into text and perform stop checks.
@@ -356,7 +359,7 @@ def process_outputs(
             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
                     new_token_ids, finish_reason, stop_reason,
-                    kv_transfer_params):
+                    kv_transfer_params, num_cached_tokens):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index d1cdd2c52750..b4c84507532a 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -77,6 +77,10 @@ def __init__(
         self.output_token_ids = ConstantList(self._output_token_ids)
         self.all_token_ids = ConstantList(self._all_token_ids)
 
+        # State
+        # The number of tokens with prefix cache hits.
+        self.num_cached_tokens = -1
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         if request.mm_inputs is not None:
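
Usage note: a minimal offline sketch of how the new counter surfaces to callers,
mirroring test_prefix_cache_stats_is_recorded above. The model name is a
placeholder, the VLLM_USE_V1 toggle is an assumption about how the v1 engine is
enabled in this setup, and exposing the value as
usage.prompt_tokens_details.cached_tokens in the Chat Completion response (per
the subject) is handled in the OpenAI serving layer, which this diff does not
touch.

    import os

    os.environ["VLLM_USE_V1"] = "1"  # assumption: v1 engine enabled via env var

    from vllm import LLM

    # Prefix caching on with 16-token blocks, as in the updated test fixture.
    llm = LLM(model="facebook/opt-125m",  # placeholder model
              enable_prefix_caching=True,
              block_size=16)

    # 17 identical tokens guarantee one fully cached 16-token block.
    prompt = {"prompt_token_ids": [101] * 17}

    _ = llm.generate([prompt])        # first pass fills the prefix cache
    outputs = llm.generate([prompt])  # second pass hits the cached block

    # New field propagated from EngineCoreOutput through the output processor.
    print(outputs[0].num_cached_tokens)  # expected: 16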