11 changes: 10 additions & 1 deletion tests/v1/core/test_scheduler_e2e.py
@@ -19,11 +19,20 @@ def model() -> LLM:
enable_prefix_caching=True,
long_prefill_token_threshold=2,
max_num_batched_tokens=6,
max_num_seqs=3)
max_num_seqs=3,
block_size=16)


def test_concurrent_partial_prefill(model):
outputs = model.generate([PROMPT] * 3)
assert len(outputs) == 3
for output in outputs:
assert len(output.outputs) == 1


def test_prefix_cache_stats_is_recorded(model):
# 17 tokens ensure that the first 16 tokens fill and cache one full block.
input_tokens = {"prompt_token_ids": [101] * 17}
_ = model.generate([input_tokens])
outputs = model.generate([input_tokens])
assert outputs[0].num_cached_tokens == 16
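
The new test above exercises the feature end to end. As a minimal usage sketch of what this change exposes to callers (the model name is an illustrative assumption, not taken from this PR), a warm-up call followed by a second call on the same tokens should report the reused block via RequestOutput.num_cached_tokens:

# Minimal usage sketch; the model name is an assumption, not part of this PR.
from vllm import LLM

llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          block_size=16)

# 17 identical tokens: the first full 16-token block becomes cacheable.
tokens = {"prompt_token_ids": [101] * 17}

_ = llm.generate([tokens])        # first pass populates the prefix cache
outputs = llm.generate([tokens])  # second pass should hit the cached block

print(outputs[0].num_cached_tokens)  # expected: 16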
5 changes: 4 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -457,7 +457,9 @@ def schedule(self) -> SchedulerOutput:
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens

# Count the number of prefix-cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
@@ -798,6 +800,7 @@ def update_from_output(
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))

else:
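The guard added in schedule() records the cache-hit count only once: num_cached_tokens starts at the -1 sentinel (see vllm/v1/request.py below), so the first scheduling pass copies num_computed_tokens into it and later passes leave it untouched. A standalone sketch of that pattern, using simplified stand-in classes rather than the real scheduler types:

# Simplified stand-ins, not the actual vLLM Request/Scheduler classes.
class Request:
    def __init__(self) -> None:
        self.num_computed_tokens = 0
        self.num_cached_tokens = -1  # -1 means "never recorded"

def schedule_step(request: Request, num_computed_tokens: int) -> None:
    request.num_computed_tokens = num_computed_tokens
    # Record the prefix-cache hit count only on the first pass.
    if request.num_cached_tokens < 0:
        request.num_cached_tokens = num_computed_tokens

req = Request()
schedule_step(req, 16)  # first schedule: 16 prompt tokens came from the cache
schedule_step(req, 20)  # later steps advance computed tokens...
assert req.num_cached_tokens == 16  # ...but the recorded hit count is stable
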
3 changes: 3 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -107,6 +107,9 @@ class EngineCoreOutput(
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None

# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0

@property
def finished(self) -> bool:
return self.finish_reason is not None
9 changes: 6 additions & 3 deletions vllm/v1/engine/output_processor.py
@@ -147,6 +147,7 @@ def make_request_output(
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Optional[RequestOutput]:

finished = finish_reason is not None
@@ -169,14 +170,15 @@ def make_request_output(
return None

return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)
kv_transfer_params, num_cached_tokens)

def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> RequestOutput:

if self.output_kind == RequestOutputKind.DELTA:
@@ -193,6 +195,7 @@ def _new_request_output(
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens,
)

def _new_completion_output(
@@ -340,7 +343,7 @@ def process_outputs(
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params

num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False

# 2) Detokenize the token ids into text and perform stop checks.
Expand All @@ -356,7 +359,7 @@ def process_outputs(
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
kv_transfer_params, num_cached_tokens):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
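The output-processor change is pure plumbing: the count is copied from EngineCoreOutput into the user-facing RequestOutput without any recomputation. A self-contained sketch with illustrative stand-in dataclasses (not the real vLLM types):

from dataclasses import dataclass

# Illustrative stand-ins for the real EngineCoreOutput / RequestOutput types.
@dataclass
class EngineCoreOutputSketch:
    request_id: str
    num_cached_tokens: int = 0  # filled in by the scheduler

@dataclass
class RequestOutputSketch:
    request_id: str
    finished: bool
    num_cached_tokens: int = 0  # now surfaced to callers

def make_request_output_sketch(core: EngineCoreOutputSketch,
                               finished: bool) -> RequestOutputSketch:
    # The processor only forwards the value; nothing is recomputed here.
    return RequestOutputSketch(core.request_id, finished,
                               core.num_cached_tokens)

out = make_request_output_sketch(
    EngineCoreOutputSketch("req-0", num_cached_tokens=16), finished=True)
assert out.num_cached_tokens == 16
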
4 changes: 4 additions & 0 deletions vllm/v1/request.py
@@ -77,6 +77,10 @@ def __init__(
self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids)

# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
Collaborator Author:
Would it be more appropriate to initialize this to 0 here?

Collaborator:
@chaunceyjiang I think it should be -1 to indicate that the variable has never been assigned.

@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None:
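Following up on the review thread above about initializing to 0 versus -1: the sentinel keeps "never scheduled" distinguishable from "scheduled with zero cache hits". A small hypothetical helper (not part of this PR) illustrating the distinction a consumer could draw:

# Hypothetical helper, not part of this PR, illustrating the -1 sentinel.
def describe_cache_hits(num_cached_tokens: int) -> str:
    if num_cached_tokens < 0:
        return "not yet recorded"
    if num_cached_tokens == 0:
        return "no prefix-cache hits"
    return f"{num_cached_tokens} tokens reused from the prefix cache"

assert describe_cache_hits(-1) == "not yet recorded"
assert describe_cache_hits(0) == "no prefix-cache hits"
assert describe_cache_hits(16) == "16 tokens reused from the prefix cache"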