76 changes: 76 additions & 0 deletions examples/experiments/offline-priority-prefix-caching.py
@@ -0,0 +1,76 @@
# ruff: noqa: E501
# SPDX-License-Identifier: Apache-2.0
from vllm import LLM, SamplingParams


def main():
block_size = 16

llm = LLM(
model="facebook/opt-125m",
enforce_eager=True,
block_size=block_size,
# 1 reserved null block + 2 blocks for ongoing compute + 2 for the free queue.
num_gpu_blocks_override=5,
)

x_tokens = {"prompt_token_ids": [101] * (block_size + 1)}
y_tokens = {"prompt_token_ids": [102] * (block_size + 1)}
a_tokens = {"prompt_token_ids": [103] * (block_size + 1)}
b_tokens = {"prompt_token_ids": [104] * (block_size + 1)}

sampling_params = SamplingParams(temperature=0.0, max_tokens=1)

print("Sending P1 requests...")
for tokens in [x_tokens, y_tokens]:
output = llm.generate([tokens],
sampling_params=sampling_params,
priority=[1])
assert output[0].num_cached_tokens == 0

# The KV cache should be [x_tokens: cached, y_tokens: cached]

print("Verifying cache hit...")
for tokens in [x_tokens, y_tokens]:
outputs = llm.generate([tokens],
sampling_params=sampling_params,
priority=[1])
assert (
outputs[0].num_cached_tokens == block_size
), f"P1 requests should cache {block_size} tokens, but got {outputs[0].num_cached_tokens}"

print("Cache hit verified.")

print("Sending P0 requests...")
for tokens in [a_tokens, b_tokens]:
outputs = llm.generate([tokens],
sampling_params=sampling_params,
priority=[0])
assert outputs[0].num_cached_tokens == 0

# The KV cache should be [x_tokens: evicted, y_tokens: cached, a_tokens: evicted, b_tokens: cached]

print("Now send request A and B again...")
for tokens in [a_tokens, b_tokens]:
outputs = llm.generate([tokens],
sampling_params=sampling_params,
priority=[0])
# A and B should evict each other's cached blocks.
assert outputs[0].num_cached_tokens == 0

# The KV cache should be [x_tokens: evicted, y_tokens: cached, a_tokens: evicted, b_tokens: cached]

print("P1's cache should be [x_tokens: evicted, y_tokens: cached]")
outputs = llm.generate([x_tokens],
sampling_params=sampling_params,
priority=[1])
assert outputs[0].num_cached_tokens == 0

outputs = llm.generate([y_tokens],
sampling_params=sampling_params,
priority=[1])
assert outputs[0].num_cached_tokens == block_size


if __name__ == "__main__":
main()
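A rough sketch of the block accounting this example leans on; treating block 0 as a reserved null block is an assumption inferred from the BlockPool test further down (6 blocks yield free blocks [1..5]), and the arithmetic below is illustrative, not part of the diff:

import math

block_size = 16
num_gpu_blocks_override = 5
usable_blocks = num_gpu_blocks_override - 1         # assumes block 0 is the reserved null block
prompt_len = block_size + 1                         # 17 tokens per request above
blocks_per_request = math.ceil(prompt_len / block_size)    # 2 blocks while a request runs
cached_blocks_per_request = prompt_len // block_size        # only the full block stays cached

print(usable_blocks, blocks_per_request, cached_blocks_per_request)  # 4 2 1
# With 4 usable blocks and 1 cached block per request, at most two request
# prefixes can stay resident at once, which is what the eviction asserts
# above rely on.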
138 changes: 138 additions & 0 deletions examples/experiments/online-priority-prefix-caching.py
@@ -0,0 +1,138 @@
# ruff: noqa: E501
# SPDX-License-Identifier: Apache-2.0
from openai import OpenAI

# Start a vllm server with the following flags:
# vllm serve \
# facebook/opt-125m \
# --port 8001 \
# --enable-prompt-tokens-details \
# --block-size 16 \
# --num-gpu-blocks-override 5

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8001/v1"

client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id


def main():
block_size = 16 # Should match the block_size in server config

# Define prompts that span roughly one block of tokens
# Using distinct integer tokens for easier tracking
# (convert to strings since the API expects string prompts)
x_prompt = " ".join([str(101)] * block_size)
y_prompt = " ".join([str(102)] * block_size)
a_prompt = " ".join([str(103)] * block_size)
b_prompt = " ".join([str(104)] * block_size)

print("Sending P1 requests...")
for prompt in [x_prompt, y_prompt]:
response = client.completions.create(model=model,
prompt=prompt,
max_tokens=1,
temperature=0.0,
extra_body={"priority": 1})
cached = 0
if hasattr(response.usage, 'prompt_tokens_details'
) and response.usage.prompt_tokens_details:
cached = response.usage.prompt_tokens_details.cached_tokens or 0

print(f"Cached tokens: {cached}")
assert cached == 0, "First request should have no cached tokens"

# The KV cache should be [x_prompt: cached, y_prompt: cached]

print("Verifying cache hit...")
for prompt in [x_prompt, y_prompt]:
response = client.completions.create(model=model,
prompt=prompt,
max_tokens=1,
temperature=0.0,
extra_body={"priority": 1})
cached = 0
if hasattr(response.usage, 'prompt_tokens_details'
) and response.usage.prompt_tokens_details:
cached = response.usage.prompt_tokens_details.cached_tokens or 0

print(f"Cached tokens: {cached}")
assert cached == block_size, f"P1 requests should cache {block_size} tokens, but got {cached}"

print("Cache hit verified.")

print("Sending P0 requests...")
for prompt in [a_prompt, b_prompt]:
response = client.completions.create(model=model,
prompt=prompt,
max_tokens=1,
temperature=0.0,
extra_body={"priority": 0})
cached = 0
if hasattr(response.usage, 'prompt_tokens_details'
) and response.usage.prompt_tokens_details:
cached = response.usage.prompt_tokens_details.cached_tokens or 0

print(f"Cached tokens: {cached}")
assert cached == 0, "First P0 request should have no cached tokens"

# The KV cache should be [x_prompt: evicted, y_prompt: cached, a_prompt: evicted, b_prompt: cached]

print("Now send request A and B again...")
for prompt in [a_prompt, b_prompt]:
response = client.completions.create(model=model,
prompt=prompt,
max_tokens=1,
temperature=0.0,
extra_body={"priority": 0})
cached = 0
if hasattr(response.usage, 'prompt_tokens_details'
) and response.usage.prompt_tokens_details:
cached = response.usage.prompt_tokens_details.cached_tokens or 0

print(f"Cached tokens: {cached}")
# A and B should evict each other's cached blocks.
assert cached == 0, f"P0 requests should evict each other's cache, but got {cached} cached tokens"

# The KV cache should be [x_prompt: evicted, y_prompt: cached, a_prompt: evicted, b_prompt: cached]

print("P1's cache should be [x_prompt: evicted, y_prompt: cached]")
response = client.completions.create(model=model,
prompt=x_prompt,
max_tokens=1,
temperature=0.0,
extra_body={"priority": 1})
cached = 0
if hasattr(
response.usage,
'prompt_tokens_details') and response.usage.prompt_tokens_details:
cached = response.usage.prompt_tokens_details.cached_tokens or 0

print(f"X cached tokens: {cached}")
assert cached == 0, f"x_prompt should be evicted, but got {cached} cached tokens"

response = client.completions.create(model=model,
prompt=y_prompt,
max_tokens=1,
temperature=0.0,
extra_body={"priority": 1})
cached = 0
if hasattr(
response.usage,
'prompt_tokens_details') and response.usage.prompt_tokens_details:
cached = response.usage.prompt_tokens_details.cached_tokens or 0

print(f"Y cached tokens: {cached}")
assert cached == block_size, f"y_prompt should cache {block_size} tokens, but got {cached} cached tokens"

print("Test completed successfully!")


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions examples/offline_inference/basic/basic.py
@@ -15,11 +15,11 @@

def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m", num_gpu_blocks_override=10)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
outputs = llm.generate(prompts, sampling_params, priority=[0, 1, 0, 0])
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
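Not part of the diff, but a possible follow-up to the prioritized call above: per-request cache reuse can be read back from RequestOutput.num_cached_tokens, the same field the offline example asserts on. A hypothetical inspection loop reusing the `outputs` list from basic.py:

# Hypothetical inspection loop, reusing `outputs` from basic.py above.
for output in outputs:
    print(f"prompt={output.prompt!r} "
          f"cached_tokens={output.num_cached_tokens} "
          f"text={output.outputs[0].text!r}")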
38 changes: 38 additions & 0 deletions tests/v1/core/test_prefix_caching_priority.py
@@ -0,0 +1,38 @@
import pytest

from vllm.v1.core.block_pool import BlockPool


def test_free_blocks_priority():
# Create a BlockPool with 6 blocks and prefix caching enabled
bp = BlockPool(num_gpu_blocks=6, enable_caching=True)
# Initially, the free list should hold all non-null blocks [1..5]
initial_free = bp.free_block_queue.get_all_free_blocks()
initial_ids = [blk.block_id for blk in initial_free]
assert initial_ids == [1, 2, 3, 4, 5]

# Allocate 2 blocks for request R0 (to simulate priority 0)
r0_blocks = bp.get_new_blocks(2)
# Allocate 2 blocks for request R1 (to simulate priority 1)
r1_blocks = bp.get_new_blocks(2)
# Remaining free blocks
remaining_ids = [
blk.block_id for blk in bp.free_block_queue.get_all_free_blocks()
]
assert remaining_ids == [5]

# Free R0 blocks (priority 0: evict before priority 1 blocks)
# Reverse within request so tail blocks freed first.
bp.free_blocks(reversed(r0_blocks), front=True)
# Free R1 blocks (priority 1: evict after priority 0 blocks)
bp.free_blocks(reversed(r1_blocks))

# Collect final free list
final_free = bp.free_block_queue.get_all_free_blocks()
final_ids = [blk.block_id for blk in final_free]

# Expected order: the remaining free block first, then R0 blocks
# (in reverse order), then R1 blocks at the tail.
expected = remaining_ids + [
r0_blocks[1].block_id, r0_blocks[0].block_id
] + [r1_blocks[1].block_id, r1_blocks[0].block_id]
assert final_ids == expected
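A hedged continuation of this test (not in the diff): whatever sits at the front of the free queue is what the next allocation hands out, so the order asserted above is also the eviction order. Sketch under the assumption that get_new_blocks pops blocks from the front of the queue:

    # Hypothetical extra check appended to test_free_blocks_priority:
    # the next allocation should consume the queue from the front,
    # i.e. R0's blocks are reused (evicted) before R1's.
    next_two = bp.get_new_blocks(2)
    assert [blk.block_id for blk in next_two] == expected[:2]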
11 changes: 10 additions & 1 deletion tests/v1/core/test_scheduler_e2e.py
@@ -19,11 +19,20 @@ def model() -> LLM:
enable_prefix_caching=True,
long_prefill_token_threshold=2,
max_num_batched_tokens=6,
max_num_seqs=3)
max_num_seqs=3,
block_size=16)


def test_concurrent_partial_prefill(model):
outputs = model.generate([PROMPT] * 3)
assert len(outputs) == 3
for output in outputs:
assert len(output.outputs) == 1


def test_prefix_cache_stats_is_recorded(model):
# 17 tokens make sure the first 16 tokens fill one full block and get cached
input_tokens = {"prompt_token_ids": [101] * 17}
_ = model.generate([input_tokens])
outputs = model.generate([input_tokens])
assert outputs[0].num_cached_tokens != 0
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -977,6 +977,7 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
27 changes: 27 additions & 0 deletions vllm/entrypoints/openai/serving_completion.py
@@ -21,6 +21,7 @@
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
@@ -47,6 +48,7 @@ def __init__(
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
@@ -60,6 +62,7 @@ def __init__(
source = "model" if source == "auto" else source
logger.info("Using default completion sampling params from %s: %s",
source, self.default_sampling_params)
self.enable_prompt_tokens_details = enable_prompt_tokens_details

async def create_completion(
self,
@@ -260,6 +263,7 @@ async def completion_stream_generator(
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None  # Track cached prompt tokens across outputs

stream_options = request.stream_options
if stream_options:
@@ -271,6 +275,11 @@

try:
async for prompt_idx, res in result_generator:
# Store cached tokens if available
if (self.enable_prompt_tokens_details
and res.num_cached_tokens is not None):
num_cached_tokens = res.num_cached_tokens

prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
@@ -370,6 +379,13 @@
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)

# Add prompt tokens details if enabled
# and cached tokens are available
if (self.enable_prompt_tokens_details
and num_cached_tokens is not None):
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)

if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
@@ -404,8 +420,14 @@
choices: list[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
num_cached_tokens = None # Store the number of cached tokens

for final_res in final_res_batch:
# Store cached tokens value if available
if (self.enable_prompt_tokens_details
and final_res.num_cached_tokens is not None):
num_cached_tokens = final_res.num_cached_tokens

prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -474,6 +496,11 @@
total_tokens=num_prompt_tokens + num_generated_tokens,
)

# Add prompt tokens details if enabled and cached tokens are available
if self.enable_prompt_tokens_details and num_cached_tokens is not None:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)

request_metadata.final_usage_info = usage

return CompletionResponse(
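A minimal client-side sketch of what the streaming change above enables, assuming a server started with --enable-prompt-tokens-details as in the online example earlier. Reading usage.prompt_tokens_details assumes a reasonably recent openai client, and stream_options plus priority are passed through extra_body as a defensive choice in case the installed client lacks the named parameters:

from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8001/v1")
model = client.models.list().data[0].id

# Stream a completion and request a final usage chunk.
stream = client.completions.create(
    model=model,
    prompt="Hello, my name is",
    max_tokens=4,
    temperature=0.0,
    stream=True,
    extra_body={
        "priority": 1,
        "stream_options": {"include_usage": True},
    },
)

final_usage = None
for chunk in stream:
    if chunk.usage is not None:
        final_usage = chunk.usage  # only the last chunk carries usage

if final_usage is not None and final_usage.prompt_tokens_details is not None:
    print("cached prompt tokens:",
          final_usage.prompt_tokens_details.cached_tokens)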