From 6709c5a780fde7e9ca7a329ee1289dc04e531d51 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 29 Apr 2025 22:48:11 -0700 Subject: [PATCH 1/6] init proto Signed-off-by: simon-mo --- vllm/v1/core/block_pool.py | 11 +++++++++-- vllm/v1/core/kv_cache_utils.py | 18 ++++++++++++++++++ vllm/v1/engine/__init__.py | 3 +++ vllm/v1/engine/processor.py | 6 ++++-- vllm/v1/request.py | 5 +++++ 5 files changed, 39 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 74f3f7852c9a..1d72377c688a 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -224,19 +224,26 @@ def touch(self, blocks: list[KVCacheBlock]) -> None: self.free_block_queue.remove(block) block.incr_ref() - def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + def free_blocks(self, + ordered_blocks: Iterable[KVCacheBlock], + front: bool = False) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. Args: ordered_blocks: A list of blocks to free ordered by their eviction priority. + front: If True, freed blocks are prepended to the free list (evicted sooner); + otherwise, appended to the tail (evicted later). """ for block in ordered_blocks: block.decr_ref() # null_block should not be added to the free list. if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.append(block) + if front: + self.free_block_queue.appendleft(block) + else: + self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3026ecc1c968..aa9c079c2600 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -248,6 +248,24 @@ def append(self, block: KVCacheBlock) -> None: block.next_free_block = None self.num_free_blocks += 1 + + def appendleft(self, block: KVCacheBlock) -> None: + """Put a block at the front of the free list and increase num_free_blocks by 1. + + Args: + block: The block to prepend. + """ + if self.free_list_head is not None: + # Link the new block before the current head. + block.next_free_block = self.free_list_head + self.free_list_head.prev_free_block = block + self.free_list_head = block + else: + # The free list is empty. + self.free_list_head = self.free_list_tail = block + block.next_free_block = None + block.prev_free_block = None + self.num_free_blocks += 1 def get_all_free_blocks(self) -> list[KVCacheBlock]: """Get all free blocks in the free list. Mainly used for testing. diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 0474669610cd..7c30c801e24c 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -62,6 +62,9 @@ class EngineCoreRequest( # belong to, to cover a race condition where the request is sent before # a wave finished notification is received. current_wave: int = 0 + # Priority of the request: 0 (normal) or 1 (high). Higher priority requests + # have their prefix cache freed last (evicted later). + priority: int = 0 class EngineCoreEventType(enum.IntEnum): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 8ae5d01574c2..38a2d0bfc3d9 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -207,8 +207,9 @@ def process_inputs( # TODO(woosuk): Support encoder-decoder models. 
self._validate_lora(lora_request) self._validate_params(params) - if priority != 0: - raise ValueError("V1 does not support priority yet.") + # Only support priority levels 0 (normal) and 1 (high). + if priority not in (0, 1): + raise ValueError("V1 only supports priority levels 0 or 1.") if trace_headers is not None: raise ValueError("V1 does not support tracing yet.") if prompt_adapter_request is not None: @@ -315,6 +316,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, + priority=priority, ) def _validate_model_inputs(self, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3b9b666f936a..1bde127b48d1 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -29,11 +29,14 @@ def __init__( arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, + priority: int = 0, ) -> None: self.request_id = request_id self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id + # Priority (0 normal, 1 high) for prefix cache eviction ordering. + self.priority = priority self.lora_request = lora_request self.structured_output_request = structured_output_request @@ -77,6 +80,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": assert is_list_of(request.mm_inputs, MultiModalKwargs), ( "mm_inputs was not updated in EngineCore.add_request") + # Preserve priority from EngineCoreRequest for cache eviction ordering return cls( request_id=request.request_id, prompt_token_ids=request.prompt_token_ids, @@ -89,6 +93,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), + priority=getattr(request, "priority", 0), ) def append_output_token_ids( From 9eaf8d959fa2e9027c99004b0022d694a8d3cfaa Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 29 Apr 2025 22:48:20 -0700 Subject: [PATCH 2/6] init proto Signed-off-by: simon-mo --- tests/v1/core/test_prefix_caching_priority.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/v1/core/test_prefix_caching_priority.py diff --git a/tests/v1/core/test_prefix_caching_priority.py b/tests/v1/core/test_prefix_caching_priority.py new file mode 100644 index 000000000000..518f0a81b1ac --- /dev/null +++ b/tests/v1/core/test_prefix_caching_priority.py @@ -0,0 +1,38 @@ +import pytest + +from vllm.v1.core.block_pool import BlockPool + + +def test_free_blocks_priority(): + # Create a BlockPool with 5 blocks and prefix caching enabled + bp = BlockPool(num_gpu_blocks=6, enable_caching=True) + # Initially, free list should contain all non-null blocks [1,2,3,4] + initial_free = bp.free_block_queue.get_all_free_blocks() + initial_ids = [blk.block_id for blk in initial_free] + assert initial_ids == [1, 2, 3, 4, 5] + + # Allocate 2 blocks for request R0 (to simulate priority 0) + r0_blocks = bp.get_new_blocks(2) + # Allocate 2 blocks for request R1 (to simulate priority 1) + r1_blocks = bp.get_new_blocks(2) + # Remaining free blocks + remaining_ids = [ + blk.block_id for blk in bp.free_block_queue.get_all_free_blocks() + ] + assert remaining_ids == [5] + + # Free R0 blocks (priority 0: evict before priority 1 blocks) + # Reverse within request so tail blocks freed first. 
+ bp.free_blocks(reversed(r0_blocks), front=True) + # Free R1 blocks (priority 1: evict after priority 0 blocks) + bp.free_blocks(reversed(r1_blocks)) + + # Collect final free list + final_free = bp.free_block_queue.get_all_free_blocks() + final_ids = [blk.block_id for blk in final_free] + + # Expected order: R0 blocks at front (in reverse order), then remaining, then R1 blocks at tail + expected = remaining_ids + [ + r0_blocks[1].block_id, r0_blocks[0].block_id + ] + [r1_blocks[1].block_id, r1_blocks[0].block_id] + assert final_ids == expected From c952dbcde91c7a2a4823e2c7d2e1c68d4b745522 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 29 Apr 2025 23:08:36 -0700 Subject: [PATCH 3/6] log and verified Signed-off-by: simon-mo --- examples/offline_inference/basic/basic.py | 4 +- vllm/v1/core/block_pool.py | 20 +++++- vllm/v1/core/kv_cache_manager.py | 3 +- vllm/v1/core/kv_cache_utils.py | 88 ++++++++++++++++------- 4 files changed, 85 insertions(+), 30 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index ae5ae7cb4834..fc85eae2a7fd 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -15,11 +15,11 @@ def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + llm = LLM(model="facebook/opt-125m", num_gpu_blocks_override=10) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) + outputs = llm.generate(prompts, sampling_params, priority=[0, 1, 0, 0]) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) for output in outputs: diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 1d72377c688a..95de8aa1c4f5 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -233,17 +233,33 @@ def free_blocks(self, Args: ordered_blocks: A list of blocks to free ordered by their eviction priority. - front: If True, freed blocks are prepended to the free list (evicted sooner); + front: If True, freed blocks are "prepended"to the free list (evicted sooner); + but still after the truly free blocks. otherwise, appended to the tail (evicted later). """ for block in ordered_blocks: + block_id = block.block_id block.decr_ref() # null_block should not be added to the free list. if block.ref_cnt == 0 and block != self.null_block: if front: - self.free_block_queue.appendleft(block) + logger.debug( + f"Freeing block {block_id} with P0 (front=True)") + # Use append_priority_0 for low priority (evict sooner) + self.free_block_queue.append_priority_0(block) else: + logger.debug( + f"Freeing block {block_id} with P1 (front=False)") + # Use append for high priority (evict later) self.free_block_queue.append(block) + # Log queue state after adding + current_queue = [ + b.block_id + for b in self.free_block_queue.get_all_free_blocks() + ] + logger.debug( + f"Free queue state after freeing {block_id}: {current_queue}" + ) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0830d8433d89..87b13e4633f6 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -309,7 +309,8 @@ def free(self, request: Request) -> None: # freed first. 
ordered_blocks = reversed(blocks) - self.block_pool.free_blocks(ordered_blocks) + self.block_pool.free_blocks(ordered_blocks, + front=(request.priority == 0)) self.num_cached_block.pop(request.request_id, None) def reset_prefix_cache(self) -> bool: diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index aa9c079c2600..ec8683ecf57b 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -42,8 +42,10 @@ class BlockHashType(NamedTuple): # variable if set such that processes can share the seed if needed. # This aligns with the behavior of Python's hash() function, which also uses # a random seed if PYTHONHASHSEED is not set. -NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( - 'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED')) +NONE_HASH = int.from_bytes( + os.urandom(32), + byteorder="big") if os.getenv('PYTHONHASHSEED') is None else sha256( + os.getenv('PYTHONHASHSEED')) class PrefixCachingMetrics: @@ -184,13 +186,20 @@ def __init__(self, blocks: list[KVCacheBlock]) -> None: self.num_free_blocks = len(blocks) # Initialize the doubly linked list of free blocks. - self.free_list_head: Optional[KVCacheBlock] = blocks[0] - self.free_list_tail: Optional[KVCacheBlock] = blocks[-1] - for i in range(self.num_free_blocks): - if i > 0: - blocks[i].prev_free_block = blocks[i - 1] - if i < self.num_free_blocks - 1: - blocks[i].next_free_block = blocks[i + 1] + self.free_list_head: Optional[ + KVCacheBlock] = blocks[0] if blocks else None + self.free_list_tail: Optional[ + KVCacheBlock] = blocks[-1] if blocks else None + # Marker for the start of the priority 1 blocks + self.priority_1_head: Optional[KVCacheBlock] = None + if blocks: + for i in range(self.num_free_blocks): + if i > 0: + blocks[i].prev_free_block = blocks[i - 1] + if i < self.num_free_blocks - 1: + blocks[i].next_free_block = blocks[i + 1] + else: + self.free_list_head = self.free_list_tail = None def popleft(self) -> KVCacheBlock: """Pop the first free block and reduce num_free_blocks by 1. @@ -202,6 +211,9 @@ def popleft(self) -> KVCacheBlock: raise ValueError("No free blocks available") block = self.free_list_head + # Update priority_1_head if the popped block was the P1 head + if block == self.priority_1_head: + self.priority_1_head = block.next_free_block self.remove(block) return block @@ -211,6 +223,10 @@ def remove(self, block: KVCacheBlock) -> None: Args: block: The block to remove. """ + # Update priority_1_head if the removed block was the P1 head + if block == self.priority_1_head: + self.priority_1_head = block.next_free_block + if block.prev_free_block is not None: # Link the previous block to the next block. block.prev_free_block.next_free_block = block.next_free_block @@ -230,12 +246,14 @@ def remove(self, block: KVCacheBlock) -> None: self.num_free_blocks -= 1 def append(self, block: KVCacheBlock) -> None: - """Put a block back into the free list and increase - num_free_blocks by 1. + """Put a block back into the tail of free list (Priority 1 behavior) + and increase num_free_blocks by 1. Args: block: The block to append. """ + logger.debug(f"Appending P1 block {block.block_id} to tail.") + # Standard append logic if self.free_list_tail is not None: # Link the last block to the new block. self.free_list_tail.next_free_block = block @@ -245,26 +263,46 @@ def append(self, block: KVCacheBlock) -> None: # The free list is empty. 
assert self.free_list_head is None self.free_list_head = self.free_list_tail = block - block.next_free_block = None + + # If this is the first P1 block, mark it. + if self.priority_1_head is None: + self.priority_1_head = block + self.num_free_blocks += 1 - - def appendleft(self, block: KVCacheBlock) -> None: - """Put a block at the front of the free list and increase num_free_blocks by 1. + + def append_priority_0(self, block: KVCacheBlock) -> None: + """Put a block back into the free list before Priority 1 blocks + (Priority 0 behavior) and increase num_free_blocks by 1. Args: - block: The block to prepend. + block: The block to append. """ - if self.free_list_head is not None: - # Link the new block before the current head. - block.next_free_block = self.free_list_head - self.free_list_head.prev_free_block = block - self.free_list_head = block - else: - # The free list is empty. - self.free_list_head = self.free_list_tail = block + logger.debug( + f"Appending P0 block {block.block_id} before P1 head (current P1 head: {self.priority_1_head.block_id if self.priority_1_head else 'None'})." + ) + if self.priority_1_head is None: + # No P1 blocks yet, append to the absolute tail like a P1 block, + # but DO NOT mark it as the P1 head. + if self.free_list_tail is not None: + self.free_list_tail.next_free_block = block + block.prev_free_block = self.free_list_tail + self.free_list_tail = block + else: + self.free_list_head = self.free_list_tail = block block.next_free_block = None - block.prev_free_block = None + else: + # Insert block just before priority_1_head + prev_block = self.priority_1_head.prev_free_block + block.next_free_block = self.priority_1_head + block.prev_free_block = prev_block + self.priority_1_head.prev_free_block = block + if prev_block is not None: + prev_block.next_free_block = block + else: + # The priority_1_head was the head, so block becomes the new head + self.free_list_head = block + self.num_free_blocks += 1 def get_all_free_blocks(self) -> list[KVCacheBlock]: From 51df30f6c2a56c15c11378dea0128b76b21b573e Mon Sep 17 00:00:00 2001 From: simon-mo Date: Wed, 30 Apr 2025 22:44:57 -0700 Subject: [PATCH 4/6] [V1] Add num_cached_tokens stats for request output Signed-off-by: simon-mo --- tests/v1/core/test_scheduler_e2e.py | 11 ++++++++++- vllm/v1/core/sched/scheduler.py | 10 +++++++--- vllm/v1/engine/__init__.py | 3 +++ vllm/v1/engine/output_processor.py | 9 +++++++-- vllm/v1/request.py | 1 + 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py index 0a79424a30b7..53525df51f72 100644 --- a/tests/v1/core/test_scheduler_e2e.py +++ b/tests/v1/core/test_scheduler_e2e.py @@ -19,7 +19,8 @@ def model() -> LLM: enable_prefix_caching=True, long_prefill_token_threshold=2, max_num_batched_tokens=6, - max_num_seqs=3) + max_num_seqs=3, + block_size=16) def test_concurrent_partial_prefill(model): @@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model): assert len(outputs) == 3 for output in outputs: assert len(output.outputs) == 1 + + +def test_prefix_cache_stats_is_recorded(model): + # 17 tokens will make sure first 16 tokens are cached in a block + input_tokens = {"prompt_token_ids": [101] * 17} + _ = model.generate([input_tokens]) + outputs = model.generate([input_tokens]) + assert outputs[0].num_cached_tokens != 0 diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 21711c9292f9..5c69424fdc1b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ 
b/vllm/v1/core/sched/scheduler.py @@ -334,6 +334,9 @@ def schedule(self) -> SchedulerOutput: # Total computed tokens (local + external). num_computed_tokens += num_external_tokens + # Update the statistic + request.num_cached_tokens = num_computed_tokens + # Number of tokens to be scheduled. # We use `request.num_tokens` instead of # `request.num_prompt_tokens` to consider the resumed requests, @@ -595,8 +598,8 @@ def _try_schedule_encoder_inputs( # only cover part of the mm input, roll back to before the mm item. if (self.scheduler_config.disable_chunked_mm_input and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): + and (num_computed_tokens + num_new_tokens) < + (start_pos + num_encoder_tokens)): num_new_tokens = start_pos - num_computed_tokens break @@ -729,7 +732,8 @@ def update_from_output( new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, stop_reason=request.stop_reason, - events=request.take_events())) + events=request.take_events(), + num_cached_tokens=request.num_cached_tokens)) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 0474669610cd..44938979b931 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -105,6 +105,8 @@ class EngineCoreOutput( stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None + num_cached_tokens: int = 0 + @property def finished(self) -> bool: return self.finish_reason is not None @@ -137,6 +139,7 @@ class EngineCoreOutputs( outputs: list[EngineCoreOutput] = [] scheduler_stats: Optional[SchedulerStats] = None timestamp: float = 0.0 + num_cached_tokens: int = 0 utility_output: Optional[UtilityOutput] = None finished_requests: Optional[set[str]] = None diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index f76c44cb8bca..95630234f823 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -146,6 +146,7 @@ def make_request_output( new_token_ids: list[int], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], + num_cached_tokens: int, ) -> Optional[RequestOutput]: finished = finish_reason is not None @@ -167,13 +168,15 @@ def make_request_output( if not outputs: return None - return self._new_request_output(request_id, outputs, finished) + return self._new_request_output(request_id, outputs, finished, + num_cached_tokens) def _new_request_output( self, request_id: str, outputs: list[CompletionOutput], finished: bool, + num_cached_tokens: int, ) -> RequestOutput: if self.output_kind == RequestOutputKind.DELTA: @@ -189,6 +192,7 @@ def _new_request_output( prompt_logprobs=prompt_logprobs, outputs=outputs, finished=finished, + num_cached_tokens=num_cached_tokens, ) def _new_completion_output( @@ -352,7 +356,8 @@ def process_outputs( # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason): + new_token_ids, finish_reason, stop_reason, + engine_core_output.num_cached_tokens): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). 
req_state.queue.put(request_output) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3b9b666f936a..2d1dc7b0814a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -51,6 +51,7 @@ def __init__( self._all_token_ids: list[int] = self.prompt_token_ids.copy() self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 + self.num_cached_tokens = 0 # Multi-modal related self.mm_positions = multi_modal_placeholders or [] From 38df11f3b92798c70cae145bafdf97815b785740 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Wed, 30 Apr 2025 22:53:39 -0700 Subject: [PATCH 5/6] lint Signed-off-by: simon-mo --- vllm/v1/core/sched/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1a6bf81861bb..cada16e4177f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -613,8 +613,8 @@ def _try_schedule_encoder_inputs( # only cover part of the mm input, roll back to before the mm item. if (self.scheduler_config.disable_chunked_mm_input and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) < - (start_pos + num_encoder_tokens)): + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens)): num_new_tokens = start_pos - num_computed_tokens break From 1358e92174c9f05583c761716642511064f2bde1 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Wed, 30 Apr 2025 23:42:56 -0700 Subject: [PATCH 6/6] working e2e Signed-off-by: simon-mo --- .../offline-priority-prefix-caching.py | 76 ++++++++++ .../online-priority-prefix-caching.py | 138 ++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/serving_completion.py | 27 ++++ vllm/v1/core/block_pool.py | 24 +-- 5 files changed, 254 insertions(+), 12 deletions(-) create mode 100644 examples/experiments/offline-priority-prefix-caching.py create mode 100644 examples/experiments/online-priority-prefix-caching.py diff --git a/examples/experiments/offline-priority-prefix-caching.py b/examples/experiments/offline-priority-prefix-caching.py new file mode 100644 index 000000000000..5daaac011a3b --- /dev/null +++ b/examples/experiments/offline-priority-prefix-caching.py @@ -0,0 +1,76 @@ +# ruff: noqa: E501 +# SPDX-License-Identifier: Apache-2.0 +from vllm import LLM, SamplingParams + + +def main(): + block_size = 16 + + llm = LLM( + model="facebook/opt-125m", + enforce_eager=True, + block_size=block_size, + # two slots for ongoing compute and two slots for free queue. 
+ num_gpu_blocks_override=5, + ) + + x_tokens = {"prompt_token_ids": [101] * (block_size + 1)} + y_tokens = {"prompt_token_ids": [102] * (block_size + 1)} + a_tokens = {"prompt_token_ids": [103] * (block_size + 1)} + b_tokens = {"prompt_token_ids": [104] * (block_size + 1)} + + sampling_params = SamplingParams(temperature=0.0, max_tokens=1) + + print("Sending P1 requests...") + for tokens in [x_tokens, y_tokens]: + output = llm.generate([tokens], + sampling_params=sampling_params, + priority=[1]) + assert output[0].num_cached_tokens == 0 + + # The KV cache should be [x_tokens: cached, y_tokens: cached] + + print("Verifying cache hit...") + for tokens in [x_tokens, y_tokens]: + outputs = llm.generate([tokens], + sampling_params=sampling_params, + priority=[1]) + assert ( + outputs[0].num_cached_tokens == block_size + ), f"P1 requests should cache {block_size} tokens, but got {outputs[0].num_cached_tokens}" + + print("Cache hit verified.") + + print("Sending P0 requests...") + for tokens in [a_tokens, b_tokens]: + outputs = llm.generate([tokens], + sampling_params=sampling_params, + priority=[0]) + assert outputs[0].num_cached_tokens == 0 + + # The KV cache should be [x_tokens: evicted, y_tokens: cached, a_tokens: evicted, b_tokens: cached] + + print("Now send request A and B again...") + for tokens in [a_tokens, b_tokens]: + outputs = llm.generate([tokens], + sampling_params=sampling_params, + priority=[0]) + # A and B should trash each other's cache. + assert outputs[0].num_cached_tokens == 0 + + # The KV cache should be [x_tokens: evicted, y_tokens: cached, a_tokens: evicted, b_tokens: cached] + + print("P1's cache should be [x_tokens: evicted, y_tokens: cached]") + outputs = llm.generate([x_tokens], + sampling_params=sampling_params, + priority=[1]) + assert outputs[0].num_cached_tokens == 0 + + outputs = llm.generate([y_tokens], + sampling_params=sampling_params, + priority=[1]) + assert outputs[0].num_cached_tokens == block_size + + +if __name__ == "__main__": + main() diff --git a/examples/experiments/online-priority-prefix-caching.py b/examples/experiments/online-priority-prefix-caching.py new file mode 100644 index 000000000000..71b9f73cf425 --- /dev/null +++ b/examples/experiments/online-priority-prefix-caching.py @@ -0,0 +1,138 @@ +# ruff: noqa: E501 +# SPDX-License-Identifier: Apache-2.0 +from openai import OpenAI + +# Start a vllm server with the following flags: +# vllm serve \ +# facebook/opt-125m \ +# --port 8001 \ +# --enable-prompt-tokens-details \ +# --block-size 16 \ +# --num-gpu-blocks-override 5 + +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8001/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) +models = client.models.list() +model = models.data[0].id + + +def main(): + block_size = 16 # Should match the block_size in server config + + # Define prompts with exact token length + # Using distinct integer tokens for easier tracking + # (convert to strings since the API expects string prompts) + x_prompt = " ".join([str(101)] * block_size) + y_prompt = " ".join([str(102)] * block_size) + a_prompt = " ".join([str(103)] * block_size) + b_prompt = " ".join([str(104)] * block_size) + + print("Sending P1 requests...") + for prompt in [x_prompt, y_prompt]: + response = client.completions.create(model=model, + prompt=prompt, + max_tokens=1, + temperature=0.0, + extra_body={"priority": 1}) + cached = 0 + if hasattr(response.usage, 'prompt_tokens_details' + ) and 
response.usage.prompt_tokens_details: + cached = response.usage.prompt_tokens_details.cached_tokens or 0 + + print(f"Cached tokens: {cached}") + assert cached == 0, "First request should have no cached tokens" + + # The KV cache should be [x_prompt: cached, y_prompt: cached] + + print("Verifying cache hit...") + for prompt in [x_prompt, y_prompt]: + response = client.completions.create(model=model, + prompt=prompt, + max_tokens=1, + temperature=0.0, + extra_body={"priority": 1}) + cached = 0 + if hasattr(response.usage, 'prompt_tokens_details' + ) and response.usage.prompt_tokens_details: + cached = response.usage.prompt_tokens_details.cached_tokens or 0 + + print(f"Cached tokens: {cached}") + assert cached == block_size, f"P1 requests should cache {block_size} tokens, but got {cached}" + + print("Cache hit verified.") + + print("Sending P0 requests...") + for prompt in [a_prompt, b_prompt]: + response = client.completions.create(model=model, + prompt=prompt, + max_tokens=1, + temperature=0.0, + extra_body={"priority": 0}) + cached = 0 + if hasattr(response.usage, 'prompt_tokens_details' + ) and response.usage.prompt_tokens_details: + cached = response.usage.prompt_tokens_details.cached_tokens or 0 + + print(f"Cached tokens: {cached}") + assert cached == 0, "First P0 request should have no cached tokens" + + # The KV cache should be [x_prompt: evicted, y_prompt: cached, a_prompt: evicted, b_prompt: cached] + + print("Now send request A and B again...") + for prompt in [a_prompt, b_prompt]: + response = client.completions.create(model=model, + prompt=prompt, + max_tokens=1, + temperature=0.0, + extra_body={"priority": 0}) + cached = 0 + if hasattr(response.usage, 'prompt_tokens_details' + ) and response.usage.prompt_tokens_details: + cached = response.usage.prompt_tokens_details.cached_tokens or 0 + + print(f"Cached tokens: {cached}") + # A and B should trash each other's cache. 
+ assert cached == 0, f"P0 requests should trash each other's cache, but got {cached} cached tokens" + + # The KV cache should be [x_prompt: evicted, y_prompt: cached, a_prompt: evicted, b_prompt: cached] + + print("P1's cache should be [x_prompt: evicted, y_prompt: cached]") + response = client.completions.create(model=model, + prompt=x_prompt, + max_tokens=1, + temperature=0.0, + extra_body={"priority": 1}) + cached = 0 + if hasattr( + response.usage, + 'prompt_tokens_details') and response.usage.prompt_tokens_details: + cached = response.usage.prompt_tokens_details.cached_tokens or 0 + + print(f"X cached tokens: {cached}") + assert cached == 0, f"x_prompt should be evicted, but got {cached} cached tokens" + + response = client.completions.create(model=model, + prompt=y_prompt, + max_tokens=1, + temperature=0.0, + extra_body={"priority": 1}) + cached = 0 + if hasattr( + response.usage, + 'prompt_tokens_details') and response.usage.prompt_tokens_details: + cached = response.usage.prompt_tokens_details.cached_tokens or 0 + + print(f"Y cached tokens: {cached}") + assert cached == block_size, f"y_prompt should cache {block_size} tokens, but got {cached} cached tokens" + + print("Test completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 136819580897..4f88941bbbb6 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -977,6 +977,7 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, ) if model_config.runner_type == "generate" else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1067f35ce240..401bf07d73ab 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -21,6 +21,7 @@ CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, + PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo) # yapf: enable @@ -47,6 +48,7 @@ def __init__( *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + enable_prompt_tokens_details: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, @@ -60,6 +62,7 @@ def __init__( source = "model" if source == "auto" else source logger.info("Using default completion sampling params from %s: %s", source, self.default_sampling_params) + self.enable_prompt_tokens_details = enable_prompt_tokens_details async def create_completion( self, @@ -260,6 +263,7 @@ async def completion_stream_generator( previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts num_prompt_tokens = [0] * num_prompts + num_cached_tokens = None # Add this to track cached tokens stream_options = request.stream_options if stream_options: @@ -271,6 +275,11 @@ async def completion_stream_generator( try: async for prompt_idx, res in result_generator: + # Store cached tokens if available + if (self.enable_prompt_tokens_details + and res.num_cached_tokens is not None): + num_cached_tokens = res.num_cached_tokens + prompt_token_ids = res.prompt_token_ids prompt_logprobs = res.prompt_logprobs prompt_text = res.prompt @@ -370,6 +379,13 @@ async def completion_stream_generator( 
completion_tokens=total_completion_tokens, total_tokens=total_prompt_tokens + total_completion_tokens) + # Add prompt tokens details if enabled + # and cached tokens are available + if (self.enable_prompt_tokens_details + and num_cached_tokens is not None): + final_usage_info.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens) + if include_usage: final_usage_chunk = CompletionStreamResponse( id=request_id, @@ -404,8 +420,14 @@ def request_output_to_completion_response( choices: list[CompletionResponseChoice] = [] num_prompt_tokens = 0 num_generated_tokens = 0 + num_cached_tokens = None # Store the number of cached tokens for final_res in final_res_batch: + # Store cached tokens value if available + if (self.enable_prompt_tokens_details + and final_res.num_cached_tokens is not None): + num_cached_tokens = final_res.num_cached_tokens + prompt_token_ids = final_res.prompt_token_ids assert prompt_token_ids is not None prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs) @@ -474,6 +496,11 @@ def request_output_to_completion_response( total_tokens=num_prompt_tokens + num_generated_tokens, ) + # Add prompt tokens details if enabled and cached tokens are available + if self.enable_prompt_tokens_details and num_cached_tokens is not None: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens) + request_metadata.final_usage_info = usage return CompletionResponse( diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 4268106e2bbd..399ed1f1963c 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -270,28 +270,28 @@ def free_blocks(self, otherwise, appended to the tail (evicted later). """ for block in ordered_blocks: - block_id = block.block_id + # block_id = block.block_id block.decr_ref() # null_block should not be added to the free list. if block.ref_cnt == 0 and block != self.null_block: if front: - logger.info("Freeing block %s with P0 (front=True)", - block_id) + # logger.info("Freeing block %s with P0 (front=True)", + # block_id) # Use append_priority_0 for low priority (evict sooner) self.free_block_queue.append_priority_0(block) else: - logger.info("Freeing block %s with P1 (front=False)", - block_id) + # logger.info("Freeing block %s with P1 (front=False)", + # block_id) # Use append for high priority (evict later) self.free_block_queue.append(block) # Log queue state after adding - current_queue = [ - (b.block_id, b.block_hash.hash_value) - for b in self.free_block_queue.get_all_free_blocks() - if b.block_hash is not None - ] - logger.info("Free queue state after freeing " - "%s: %s", block_id, current_queue) + # current_queue = [ + # (b.block_id, b.block_hash.hash_value) + # for b in self.free_block_queue.get_all_free_blocks() + # if b.block_hash is not None + # ] + # logger.info("Free queue state after freeing " + # "%s: %s", block_id, current_queue) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF
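
The policy that ties these patches together lives in the free-block queue: blocks that were never cached stay at the head and are evicted first, blocks freed by priority-0 requests are inserted just before the first priority-1 block (so after the truly free blocks), and blocks freed by priority-1 requests go to the absolute tail and are evicted last. The sketch below is a minimal, standalone model of that ordering policy and is not vLLM code: Block and FreeQueue are simplified stand-ins for KVCacheBlock and FreeKVCacheBlockQueue, written only so the resulting eviction order can be checked without building vLLM.

# --- Illustrative sketch (not part of the patches above) --------------------
from typing import Optional


class Block:
    """Simplified stand-in for KVCacheBlock: just an id and list pointers."""

    def __init__(self, block_id: int) -> None:
        self.block_id = block_id
        self.prev: Optional["Block"] = None
        self.next: Optional["Block"] = None


class FreeQueue:
    """Doubly linked free list: the head is evicted first, the tail last."""

    def __init__(self, blocks: list) -> None:
        self.head: Optional[Block] = blocks[0] if blocks else None
        self.tail: Optional[Block] = blocks[-1] if blocks else None
        for left, right in zip(blocks, blocks[1:]):
            left.next, right.prev = right, left
        # First block freed by a priority-1 request, if any
        # (the patch's priority_1_head marker).
        self.p1_head: Optional[Block] = None

    def _append_tail(self, block: Block) -> None:
        block.prev, block.next = self.tail, None
        if self.tail is not None:
            self.tail.next = block
        else:
            self.head = block
        self.tail = block

    def append(self, block: Block) -> None:
        """Priority-1 free: go to the absolute tail (evicted last)."""
        self._append_tail(block)
        if self.p1_head is None:
            self.p1_head = block

    def append_priority_0(self, block: Block) -> None:
        """Priority-0 free: land after truly free blocks but before any
        priority-1 blocks, i.e. just ahead of p1_head."""
        if self.p1_head is None:
            # No P1 blocks queued yet: plain tail append, but the block is
            # NOT marked as the P1 head.
            self._append_tail(block)
            return
        prev = self.p1_head.prev
        block.prev, block.next = prev, self.p1_head
        self.p1_head.prev = block
        if prev is not None:
            prev.next = block
        else:
            self.head = block

    def order(self) -> list:
        out, cur = [], self.head
        while cur is not None:
            out.append(cur.block_id)
            cur = cur.next
        return out


if __name__ == "__main__":
    # Block 5 was never touched (truly free), mirroring the unit test above.
    q = FreeQueue([Block(5)])
    q.append_priority_0(Block(2))  # blocks freed by a priority-0 request
    q.append_priority_0(Block(1))
    q.append(Block(4))             # blocks freed by a priority-1 request
    q.append(Block(3))
    print(q.order())  # [5, 2, 1, 4, 3]

    # A priority-0 block freed *after* the P1 blocks still slots in before
    # the P1 region, so P1 prefixes keep surviving eviction the longest.
    q.append_priority_0(Block(6))
    print(q.order())  # [5, 2, 1, 6, 4, 3]
# ----------------------------------------------------------------------------

Running the sketch prints [5, 2, 1, 4, 3] and then [5, 2, 1, 6, 4, 3]: truly free blocks are reclaimed first, then priority-0 prefixes, and priority-1 prefixes only as a last resort. This is the same pattern that test_free_blocks_priority asserts and that the offline/online example scripts check end to end via num_cached_tokens.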