
Commit b046cf7

[Feature][V1]: supports cached_tokens in response usage (#18149)
Co-authored-by: simon-mo <[email protected]>
1 parent 54af915 commit b046cf7
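
With this change, the prefix-cache hit count is carried through to each RequestOutput, so the OpenAI-compatible server can report it in the response usage. As a rough, hypothetical client-side check (it assumes a locally running vLLM server and that the frontend maps the counter to usage.prompt_tokens_details.cached_tokens; the model name and endpoint below are illustrative, not part of this commit):

from openai import OpenAI

# Hypothetical sketch: point the OpenAI client at a local vLLM server and
# inspect the cached-token count for a repeated prompt. Usage fields are read
# defensively because this commit only shows the engine-side plumbing.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

prompt = "Explain prefix caching. " * 20
for attempt in range(2):
    resp = client.completions.create(model="facebook/opt-125m",
                                     prompt=prompt,
                                     max_tokens=8)
    details = getattr(resp.usage, "prompt_tokens_details", None)
    cached = getattr(details, "cached_tokens", None) if details else None
    print(f"attempt {attempt}: cached_tokens={cached}")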

File tree

5 files changed: +27 -5 lines changed


tests/v1/core/test_scheduler_e2e.py

Lines changed: 10 additions & 1 deletion
@@ -19,11 +19,20 @@ def model() -> LLM:
         enable_prefix_caching=True,
         long_prefill_token_threshold=2,
         max_num_batched_tokens=6,
-        max_num_seqs=3)
+        max_num_seqs=3,
+        block_size=16)


 def test_concurrent_partial_prefill(model):
     outputs = model.generate([PROMPT] * 3)
     assert len(outputs) == 3
     for output in outputs:
         assert len(output.outputs) == 1
+
+
+def test_prefix_cache_stats_is_recorded(model):
+    # 17 tokens will make sure first 16 tokens are cached in a block
+    input_tokens = {"prompt_token_ids": [101] * 17}
+    _ = model.generate([input_tokens])
+    outputs = model.generate([input_tokens])
+    assert outputs[0].num_cached_tokens == 16
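
The new test drives the whole path end to end. A standalone sketch of the same check through the offline LLM API (the model name is illustrative; the 17-token prompt and block_size=16 come straight from the test, so the second run should report one full 16-token block as cached):

from vllm import LLM

# Mirrors test_prefix_cache_stats_is_recorded: with prefix caching on and a
# 16-token block size, repeating a 17-token prompt reuses exactly one block.
llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          block_size=16)

prompt = {"prompt_token_ids": [101] * 17}
llm.generate([prompt])                # first run warms the prefix cache
outputs = llm.generate([prompt])      # second run hits the cached block
print(outputs[0].num_cached_tokens)   # expected: 16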

vllm/v1/core/sched/scheduler.py

Lines changed: 4 additions & 1 deletion
@@ -457,7 +457,9 @@ def schedule(self) -> SchedulerOutput:
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
-
+            # Count the number of prefix cached tokens.
+            if request.num_cached_tokens < 0:
+                request.num_cached_tokens = num_computed_tokens
             # Encoder-related.
             if encoder_inputs_to_schedule:
                 scheduled_encoder_inputs[request.request_id] = (

@@ -798,6 +800,7 @@ def update_from_output(
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
+                        num_cached_tokens=request.num_cached_tokens,
                     ))

             else:
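
The guard in the first hunk fires only once per request: num_cached_tokens starts at -1 (see vllm/v1/request.py below), so the count is captured on the first scheduling pass, when num_computed_tokens reflects prefix-cache reuse, and is left untouched on later passes. A stripped-down sketch of that sentinel pattern, with hypothetical names:

class RequestSketch:
    """Toy stand-in for vllm.v1.request.Request (illustrative only)."""

    def __init__(self) -> None:
        self.num_cached_tokens = -1   # -1 means "not recorded yet"
        self.num_computed_tokens = 0


def on_schedule(req: RequestSketch, num_computed_tokens: int) -> None:
    req.num_computed_tokens = num_computed_tokens
    # Record prefix-cache reuse only once, on the first scheduling pass;
    # later passes reflect ordinary decode progress, not cache hits.
    if req.num_cached_tokens < 0:
        req.num_cached_tokens = num_computed_tokens


req = RequestSketch()
on_schedule(req, 16)    # initial prefill: 16 tokens served from the cache
on_schedule(req, 17)    # later pass: the recorded count stays at 16
assert req.num_cached_tokens == 16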

vllm/v1/engine/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -107,6 +107,9 @@ class EngineCoreOutput(
     events: Optional[list[EngineCoreEvent]] = None
     kv_transfer_params: Optional[dict[str, Any]] = None

+    # The number of tokens with prefix cache hits.
+    num_cached_tokens: int = 0
+
     @property
     def finished(self) -> bool:
         return self.finish_reason is not None

vllm/v1/engine/output_processor.py

Lines changed: 6 additions & 3 deletions
@@ -147,6 +147,7 @@ def make_request_output(
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
         kv_transfer_params: Optional[dict[str, Any]] = None,
+        num_cached_tokens: int = 0,
     ) -> Optional[RequestOutput]:

         finished = finish_reason is not None

@@ -169,14 +170,15 @@ def make_request_output(
             return None

         return self._new_request_output(request_id, outputs, finished,
-                                        kv_transfer_params)
+                                        kv_transfer_params, num_cached_tokens)

     def _new_request_output(
         self,
         request_id: str,
         outputs: list[CompletionOutput],
         finished: bool,
         kv_transfer_params: Optional[dict[str, Any]] = None,
+        num_cached_tokens: int = 0,
     ) -> RequestOutput:

         if self.output_kind == RequestOutputKind.DELTA:

@@ -193,6 +195,7 @@ def _new_request_output(
             outputs=outputs,
             finished=finished,
             kv_transfer_params=kv_transfer_params,
+            num_cached_tokens=num_cached_tokens,
         )

     def _new_completion_output(

@@ -340,7 +343,7 @@ def process_outputs(
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
             kv_transfer_params = engine_core_output.kv_transfer_params
-
+            num_cached_tokens = engine_core_output.num_cached_tokens
             req_state.is_prefilling = False

             # 2) Detokenize the token ids into text and perform stop checks.

@@ -356,7 +359,7 @@ def process_outputs(
             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
                     new_token_ids, finish_reason, stop_reason,
-                    kv_transfer_params):
+                    kv_transfer_params, num_cached_tokens):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)

vllm/v1/request.py

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,10 @@ def __init__(
         self.output_token_ids = ConstantList(self._output_token_ids)
         self.all_token_ids = ConstantList(self._all_token_ids)

+        # State
+        # The number of tokens with prefix cache hits.
+        self.num_cached_tokens = -1
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         if request.mm_inputs is not None:
