From 67b75ee23782002ce4f12fd2d0e9854f5c66639e Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Wed, 23 Apr 2025 08:11:28 +0000 Subject: [PATCH 01/18] Prevent side-channel attacks via cache salting Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- docs/source/design/v1/prefix_caching.md | 18 ++- tests/entrypoints/openai/test_serving_chat.py | 20 +++ tests/v1/core/test_kv_cache_utils.py | 43 ++++- tests/v1/core/test_prefix_caching.py | 64 +++++++- vllm/entrypoints/openai/protocol.py | 9 ++ vllm/entrypoints/openai/serving_engine.py | 3 + vllm/inputs/data.py | 10 +- vllm/inputs/preprocess.py | 152 +++++++++--------- vllm/multimodal/inputs.py | 5 + vllm/multimodal/processing.py | 3 +- vllm/v1/core/kv_cache_utils.py | 13 +- vllm/v1/engine/__init__.py | 1 + vllm/v1/engine/processor.py | 1 + vllm/v1/request.py | 3 + 14 files changed, 263 insertions(+), 82 deletions(-) diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md index ec1f3cb8d64a..26c9bfcc4d50 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/source/design/v1/prefix_caching.md @@ -16,7 +16,7 @@ In the example above, the KV cache in the first block can be uniquely identified * Parent hash value: The hash value of the parent hash block. * Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision. -* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below). +* Extra hashes: Other values required to make this block unique, such as LoRA IDs, multi-modality input hashes (see the example below), and cache salts to isolate caches in multi-tenant environments. > **Note 1:** We only cache full blocks. @@ -76,6 +76,22 @@ Block 3 In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow. +**Cache Isolation for Security** +To improve privacy in shared environments, vLLM supports isolating prefix cache reuse through optional per-request salting. By including a `cache_salt` in the request, this value is injected into the hash of the first block, ensuring that only requests with the same salt can reuse cached KV blocks. This prevents timing-based attacks where an adversary could infer cached content by observing latency differences. This offers protection without compromising performance. + +```json +{ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Here is a document with details about the world series: ..."}, + {"role": "user", "content": "Who won the world series in 2020?"} + ], + "cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==" +} +``` + +With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others. + ## Data Structure The prefix caching in vLLM v1 is implemented in the KV cache manager. 
The basic building block is the “Block” data class (simplified): diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 19d16713b209..68e7172938a1 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -272,3 +272,23 @@ def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + # Test cache_salt + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + ) + + # By default cache_salt in the engine prompt is not set + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + assert "cache_salt" not in mock_engine.generate.call_args.args[0] + + # Test with certain cache_salt + req.cache_salt = "test_salt" + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt" diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e73e08e74b0d..e8069b8c6d7f 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -29,7 +29,8 @@ def make_request(request_id, prompt_token_ids, mm_positions=None, - mm_hashes=None): + mm_hashes=None, + cache_salt=None): if mm_positions is None: multi_modal_inputs = None else: @@ -45,6 +46,7 @@ def make_request(request_id, eos_token_id=100, arrival_time=0, lora_request=None, + cache_salt=cache_salt, ) @@ -213,6 +215,45 @@ def test_generate_block_hash_extra_keys_no_mm_inputs(): assert next_mm_idx == 0 +def test_generate_block_hash_extra_keys_cache_salt(): + request = make_request( + request_id=0, + prompt_token_ids=[_ for _ in range(6)], + mm_positions=None, + mm_hashes=None, + cache_salt="salt", + ) + + # salt is added for the first token + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0) + assert extra_keys == ('salt', ) + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0) + assert extra_keys == ('salt', ) + + # no salt added for other tokens + extra_keys, _ = generate_block_hash_extra_keys(request, 1, 2, 0) + assert extra_keys is None + extra_keys, _ = generate_block_hash_extra_keys(request, 6, 10, 0) + assert extra_keys is None + + # works together with other extra keys + request_mm = make_request( + request_id=0, + prompt_token_ids=[_ for _ in range(20)], + mm_positions=[ + PlaceholderRange(offset=0, length=5), + ], + mm_hashes=["hash1"], + cache_salt="salt", + ) + + # Test with no extra keys + extra_keys, next_mm_idx = generate_block_hash_extra_keys( + request_mm, 0, 5, 0) + assert extra_keys == ("hash1", "salt") + assert next_mm_idx == 1 + + @pytest.mark.parametrize("hash_fn", [sha256, hash]) def test_hash_block_tokens(hash_fn): parent_block_hash = 123 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index b2e8ff61450c..ae4bd95d22aa 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -21,7 +21,8 @@ def make_request(request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, - prompt_logprobs: Optional[int] = None): + prompt_logprobs: Optional[int] = None, + cache_salt: Optional[str] = None): if mm_positions is None: multi_modal_inputs = None else: @@ -38,6 +39,7 @@ def make_request(request_id, eos_token_id=100, arrival_time=0, 
lora_request=None, + cache_salt=cache_salt, ) @@ -603,6 +605,66 @@ def test_mm_prefix_caching(): assert num_computed_tokens == 3 * 16 +def test_cache_key_salting(): + """ + This tests that cache salts are applied during hashing and the cache + is separated cache as expected. + """ + block_size = 16 + manager = KVCacheManager( + make_kv_cache_config(block_size, 11), + max_model_len=8192, + enable_caching=True, + ) + + # 3 complete blocks and an incomplete block with 11 tokens. + common_token_ids = [i for i in range(3) for _ in range(block_size)] + token_ids = common_token_ids + [3] * 11 + req0 = make_request("0", token_ids, cache_salt="salt1") + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + + # Completed block should have hashes with extra keys. + assert not computed_blocks + assert num_computed_tokens == 0 + block_hashes = manager.req_to_block_hashes[req0.request_id] + assert len(block_hashes) == 3 + assert block_hashes[0].extra_keys == ("salt1", ) + assert block_hashes[1].extra_keys is None + assert block_hashes[2].extra_keys is None + + blocks = manager.allocate_slots(req0, 59, computed_blocks) + assert [b.block_id for b in blocks] == [1, 2, 3, 4] + req0.num_computed_tokens = 59 + + # Append slots without allocating a new block. + for _ in range(5): + req0.append_output_token_ids(8) + new_blocks = manager.allocate_slots(req0, 5) + assert new_blocks is not None and len(new_blocks) == 0 + + # Now one more block that should not have extra keys. + assert len(block_hashes) == 4 + assert block_hashes[3].extra_keys is None + + # Test cache hit with a new request that has the same salt. + token_ids = common_token_ids + [4] * 11 + req1 = make_request("1", token_ids, cache_salt="salt1") + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + # Should match only a prefix of 3 blocks. + assert len(computed_blocks) == 3 + assert num_computed_tokens == 3 * block_size + + # Test cache miss with same content but different salt. + token_ids = common_token_ids + [4] * 11 + req2 = make_request("2", token_ids, cache_salt="salt2") + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert len(computed_blocks) == 0 + assert num_computed_tokens == 0 + block_hashes = manager.req_to_block_hashes[req2.request_id] + assert len(block_hashes) == 3 + assert block_hashes[0].extra_keys == ("salt2", ) + + def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ This is a unit test that tests the correctness of the allocate_slots diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d444442a9762..218a22745209 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -408,6 +408,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + cache_salt: Optional[str] = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. 
The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit).")) # doc: end-chat-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index c3121eff562d..477a8f79ac93 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -470,6 +470,9 @@ async def _preprocess_chat( if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + if request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + return conversation, [request_prompt], [engine_prompt] def _log_inputs( diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 970b36bca9be..6f9da0dae740 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -141,11 +141,17 @@ class TokenInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + def token_inputs( prompt_token_ids: list[int], token_type_ids: Optional[list[int]] = None, prompt: Optional[str] = None, + cache_salt: Optional[str] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -154,6 +160,8 @@ def token_inputs( inputs["prompt"] = prompt if token_type_ids is not None: inputs["token_type_ids"] = token_type_ids + if cache_salt is not None: + inputs["cache_salt"] = cache_salt return inputs @@ -217,7 +225,7 @@ def zip_enc_dec_prompts( """ Zip encoder and decoder prompts together into a list of :class:`ExplicitEncoderDecoderPrompt` instances. - + ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same dictionary will be used for every encoder/decoder prompt. If an iterable is provided, it will be zipped with the encoder/decoder prompts. 
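A minimal sketch of how the extended `token_inputs()` helper behaves with the new field (assuming `token_inputs` is exported from `vllm.inputs`; the salt string below is only illustrative, not a recommended value):

```python
# Sketch: after this change, token_inputs() only carries the "cache_salt"
# key when a salt is supplied, so downstream code can probe it with .get().
from vllm.inputs import token_inputs  # assumed import path

salted = token_inputs(
    prompt_token_ids=[1, 2, 3],
    cache_salt="example-salt",  # illustrative value only
)
assert salted["cache_salt"] == "example-salt"

unsalted = token_inputs(prompt_token_ids=[1, 2, 3])
assert "cache_salt" not in unsalted  # NotRequired key stays absent
```

Because the salt is mixed only into the hash of the first block (see the `kv_cache_utils` change below), requests that share a salt still get prefix cache hits among themselves, while requests without a matching salt never reuse those blocks.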
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 56b60b893913..f9652a413c0e 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -17,7 +17,8 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) -from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt +from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, + is_explicit_encoder_decoder_prompt, parse_singleton_prompt) logger = init_logger(__name__) @@ -235,6 +236,7 @@ def _process_multimodal( mm_processor_kwargs: Optional[Mapping[str, object]], lora_request: Optional[LoRARequest], return_mm_hashes: bool = False, + cache_salt: Optional[str] = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -254,8 +256,11 @@ def _process_multimodal( if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply(prompt, + mm_data, + mm_processor_kwargs, + return_mm_hashes, + cache_salt=cache_salt) async def _process_multimodal_async( self, @@ -264,6 +269,7 @@ async def _process_multimodal_async( mm_processor_kwargs: Optional[Mapping[str, object]], lora_request: Optional[LoRARequest], return_mm_hashes: bool = False, + cache_salt: Optional[str] = None, ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" # At the moment on model (PrithviGeoSpatialMAE) requires to be @@ -280,8 +286,29 @@ async def _process_multimodal_async( if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) + return mm_processor.apply(prompt, + mm_data, + mm_processor_kwargs, + return_mm_hashes, + cache_salt=cache_salt) + + def _get_prompt_data(parsed_prompt: Union[ParsedStrPrompt, + ParsedTextPrompt, + ParsedTokensPrompt]): + prompt_text = None + prompt_token_ids = None + token_type_ids = None + + content = parsed_prompt["content"] + if parsed_prompt["type"] == "tokens": + prompt_token_ids = content.get("prompt_token_ids") + token_type_ids = content.get("token_type_ids") + elif parsed_prompt["type"] == "text": + prompt_text = content["prompt"] + else: + assert_never(parsed_prompt) + + return prompt_text, prompt_token_ids, token_type_ids def _prompt_to_llm_inputs( self, @@ -304,40 +331,35 @@ def _prompt_to_llm_inputs( * :class:`SingletonInputs` instance """ parsed = parse_singleton_prompt(prompt) + content = parsed["content"] if parsed["type"] == "str": - prompt_text = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + content, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) return token_inputs( - prompt=prompt_text, + prompt=content, prompt_token_ids=prompt_token_ids, ) - if parsed["type"] == "tokens": - tokens_content = parsed["content"] + prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( + parsed) - prompt_token_ids = tokens_content["prompt_token_ids"] - token_type_ids = tokens_content.get("token_type_ids") - multi_modal_data = tokens_content.get("multi_modal_data") - mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + multi_modal_data = content.get("multi_modal_data") + mm_processor_kwargs = content.get("mm_processor_kwargs") + cache_salt = content.get("cache_salt") - if multi_modal_data is not None: - return self._process_multimodal( - prompt_token_ids, - 
multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - - return token_inputs( - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, + if multi_modal_data is not None: + return self._process_multimodal( + prompt_token_ids if prompt_text is None else prompt_text, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + cache_salt=cache_salt, ) if parsed["type"] == "text": @@ -362,12 +384,12 @@ def _prompt_to_llm_inputs( tokenization_kwargs=tokenization_kwargs, ) - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) - - assert_never(parsed) + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + cache_salt=cache_salt, + ) async def _prompt_to_llm_inputs_async( self, @@ -378,65 +400,49 @@ async def _prompt_to_llm_inputs_async( ) -> SingletonInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) + content = parsed["content"] if parsed["type"] == "str": - prompt_text = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt=content, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) return token_inputs( - prompt=prompt_text, + prompt=content, prompt_token_ids=prompt_token_ids, ) - if parsed["type"] == "tokens": - tokens_content = parsed["content"] - - prompt_token_ids = tokens_content["prompt_token_ids"] - multi_modal_data = tokens_content.get("multi_modal_data") - mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - - if multi_modal_data is not None: - return await self._process_multimodal_async( - prompt_token_ids, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - - return token_inputs(prompt_token_ids=prompt_token_ids) + prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( + parsed) - if parsed["type"] == "text": - text_content = parsed["content"] + multi_modal_data = content.get("multi_modal_data") + mm_processor_kwargs = content.get("mm_processor_kwargs") + cache_salt = content.get("cache_salt") - prompt_text = text_content["prompt"] - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") - - if multi_modal_data is not None: - return await self._process_multimodal_async( - prompt_text, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - - prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + if multi_modal_data is not None: + return await self._process_multimodal_async( + prompt_token_ids if prompt_text is None else prompt_text, + multi_modal_data, + mm_processor_kwargs, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + cache_salt=cache_salt, ) - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, + if prompt_token_ids is None: + prompt_token_ids = self._tokenize_prompt_async( + prompt_text, + lora_request=lora_request, ) - assert_never(parsed) + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + cache_salt=cache_salt, + ) def _build_enc_dec_llm_inputs( self, diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 6855808e8e44..6c82c9a802bc 100644 --- a/vllm/multimodal/inputs.py +++ 
b/vllm/multimodal/inputs.py @@ -826,6 +826,11 @@ class MultiModalInputs(TypedDict): :code:`prompt_token_ids`. """ + cache_salt: Optional[str] + """ + Optional cache salt to be used for prefix caching. + """ + class MultiModalEncDecInputs(MultiModalInputs): """ diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index d6ba8f1bcffe..077585947c98 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1777,6 +1777,7 @@ def apply( mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, mm_placeholders=mm_placeholder_ranges, + cache_salt=cache_salt, ) @@ -1789,7 +1790,7 @@ def create_encoder_prompt( mm_data: MultiModalDataDict, ) -> Union[str, list[int]]: """ - Create input prompt for the encoder. HF processor will be applied on + Create input prompt for the encoder. HF processor will be applied on this prompt during profiling and generation. """ raise NotImplementedError diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 3026ecc1c968..27c515835087 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -275,7 +275,10 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. - return bool(request.mm_positions) or (request.lora_request is not None) + # Request with provided cache salt need to include the salt. + return bool(request.mm_positions) or (request.lora_request + is not None) or (request.cache_salt + is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -380,8 +383,10 @@ def generate_block_hash_extra_keys( mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( request, start_token_idx, end_token_idx, start_mm_idx) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) + cache_salt_keys: list[str] = [request.cache_salt] if ( + start_token_idx == 0 and request.cache_salt) else [] - extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys if not extra_keys: return None, new_start_mm_idx @@ -657,10 +662,10 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ Only models with one type of KV cache are supported yet. This function tries - to convert the KV cache specs to one type if the model is a hybrid model + to convert the KV cache specs to one type if the model is a hybrid model with multiple type of KV cache. It will convert all SlidingWindowSpec to FullAttentionSpec if both types are present. 
- + Args: kv_cache_spec: The kv cache spec of each attention layer in the model """ diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 0474669610cd..e33d1a1e5dcd 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -57,6 +57,7 @@ class EngineCoreRequest( eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] + cache_salt: Optional[str] # Used in DP case to indicate which wave of requests this is expected to # belong to, to cover a race condition where the request is sent before diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index b98a31773a15..143bc8285ebb 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -317,6 +317,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, + cache_salt=processed_inputs.get("cache_salt"), ) def _validate_model_inputs(self, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 3b9b666f936a..fde366d61c7d 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -29,6 +29,7 @@ def __init__( arrival_time: float, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, + cache_salt: Optional[str] = None, ) -> None: self.request_id = request_id self.sampling_params = sampling_params @@ -51,6 +52,7 @@ def __init__( self._all_token_ids: list[int] = self.prompt_token_ids.copy() self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 + self.cache_salt: Optional[str] = cache_salt # Multi-modal related self.mm_positions = multi_modal_placeholders or [] @@ -89,6 +91,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), + cache_salt=request.cache_salt, ) def append_output_token_ids( From 1b4032b5191fbd441ac7896cfdcbd029f34b7b67 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Wed, 23 Apr 2025 11:30:16 +0000 Subject: [PATCH 02/18] fix missing arg Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/data.py | 10 ++++++++++ vllm/inputs/preprocess.py | 27 +++++++++++++++------------ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 6f9da0dae740..167189ed108e 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -28,6 +28,11 @@ class TextPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + class TokensPrompt(TypedDict): """Schema for a tokenized prompt.""" @@ -52,6 +57,11 @@ class TokensPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. 
+ """ + SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index f9652a413c0e..fbfb36dbb7f6 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -17,7 +17,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) -from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, +from .parse import (ParsedTextPrompt, ParsedTokensPrompt, is_explicit_encoder_decoder_prompt, parse_singleton_prompt) logger = init_logger(__name__) @@ -292,19 +292,18 @@ async def _process_multimodal_async( return_mm_hashes, cache_salt=cache_salt) - def _get_prompt_data(parsed_prompt: Union[ParsedStrPrompt, - ParsedTextPrompt, - ParsedTokensPrompt]): + def _get_prompt_data(self, parsed_prompt: Union[ParsedTextPrompt, + ParsedTokensPrompt]): prompt_text = None prompt_token_ids = None token_type_ids = None - content = parsed_prompt["content"] if parsed_prompt["type"] == "tokens": + content = parsed_prompt["content"] prompt_token_ids = content.get("prompt_token_ids") token_type_ids = content.get("token_type_ids") elif parsed_prompt["type"] == "text": - prompt_text = content["prompt"] + prompt_text = parsed_prompt["content"]["prompt"] else: assert_never(parsed_prompt) @@ -331,23 +330,24 @@ def _prompt_to_llm_inputs( * :class:`SingletonInputs` instance """ parsed = parse_singleton_prompt(prompt) - content = parsed["content"] if parsed["type"] == "str": + prompt_text = parsed["content"] prompt_token_ids = self._tokenize_prompt( - content, + prompt_text, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) return token_inputs( - prompt=content, + prompt=prompt_text, prompt_token_ids=prompt_token_ids, ) prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( parsed) + content = parsed["content"] multi_modal_data = content.get("multi_modal_data") mm_processor_kwargs = content.get("mm_processor_kwargs") cache_salt = content.get("cache_salt") @@ -400,23 +400,24 @@ async def _prompt_to_llm_inputs_async( ) -> SingletonInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) - content = parsed["content"] if parsed["type"] == "str": + prompt_text = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt=content, + prompt=prompt_text, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, ) return token_inputs( - prompt=content, + prompt=prompt_text, prompt_token_ids=prompt_token_ids, ) prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( parsed) + content = parsed["content"] multi_modal_data = content.get("multi_modal_data") mm_processor_kwargs = content.get("mm_processor_kwargs") cache_salt = content.get("cache_salt") @@ -512,6 +513,7 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( mm_kwargs=inputs["mm_kwargs"], mm_hashes=inputs["mm_hashes"], mm_placeholders=inputs["mm_placeholders"], + cache_salt=inputs["cache_salt"], ) else: decoder_inputs = MultiModalInputs( @@ -521,6 +523,7 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( mm_kwargs=inputs["mm_kwargs"], mm_hashes=inputs["mm_hashes"], mm_placeholders=inputs["mm_placeholders"], + cache_salt=inputs["cache_salt"], ) elif inputs["type"] == "token": # Text-only inputs From c3c728607f8abd41e7770bc6d0a52bd656d6fc39 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 07:46:28 +0000 
Subject: [PATCH 03/18] return error for cache_salt with V0 engine Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- docs/source/design/v1/prefix_caching.md | 2 ++ tests/entrypoints/openai/test_serving_chat.py | 20 +++++++++++++++ vllm/entrypoints/openai/protocol.py | 25 +++++++++++++++---- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md index 26c9bfcc4d50..ec661d8ec641 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/source/design/v1/prefix_caching.md @@ -92,6 +92,8 @@ To improve privacy in shared environments, vLLM supports isolating prefix cache With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others. +> **Note:** Cache isolation is not supported in engine V0. + ## Data Structure The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified): diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 68e7172938a1..5e11af8cf892 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -273,6 +273,26 @@ def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + +def test_serving_chat_did_set_correct_cache_salt(): + mock_model_config = MockModelConfig() + + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + serving_chat = OpenAIServingChat(mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + # Test cache_salt req = ChatCompletionRequest( model=MODEL_NAME, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 218a22745209..d514eeff09e5 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -14,6 +14,7 @@ ValidationInfo, field_validator, model_validator) from typing_extensions import TypeAlias +from vllm import envs from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.logger import init_logger from vllm.pooling_params import PoolingParams @@ -416,7 +417,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit).")) + "to 256 bit). 
Not supported by vLLM engine V0.")) # doc: end-chat-completion-extra-params @@ -735,6 +736,20 @@ def check_generation_prompt(cls, data): "`add_generation_prompt` to True.") return data + @model_validator(mode="before") + @classmethod + def check_cache_salt_support(cls, data): + if data.get("cache_salt") is not None: + if not envs.VLLM_USE_V1: + raise ValueError( + "Parameter 'cache_salt' is not supported with " + "this instance of vLLM, which uses engine V0.") + if not isinstance(data["cache_salt"], str) or len( + data["cache_salt"]) == 0: + raise ValueError("Parameter 'cache_salt' must be a " + "non-empty string if provided.") + return data + class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -1631,9 +1646,9 @@ class TranscriptionRequest(OpenAIBaseModel): # doc: begin-transcription-extra-params stream: Optional[bool] = False - """Custom field not present in the original OpenAI definition. When set, + """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat - Completion endpoint. + Completion endpoint. """ # Flattened stream option to simplify form data. stream_include_usage: Optional[bool] = False @@ -1651,7 +1666,7 @@ class TranscriptionRequest(OpenAIBaseModel): """ top_p: Optional[float] = None - """Enables nucleus (top-p) sampling, where tokens are selected from the + """Enables nucleus (top-p) sampling, where tokens are selected from the smallest possible set whose cumulative probability exceeds `p`. """ @@ -1659,7 +1674,7 @@ class TranscriptionRequest(OpenAIBaseModel): """Limits sampling to the `k` most probable tokens at each step.""" min_p: Optional[float] = None - """Filters out tokens with a probability lower than `min_p`, ensuring a + """Filters out tokens with a probability lower than `min_p`, ensuring a minimum likelihood threshold during sampling. 
""" From fb85c235168a1f9c625e1ef9c9fd0f38668af4ac Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 08:42:05 +0000 Subject: [PATCH 04/18] merge conflict Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index fbfb36dbb7f6..11ac18d8a33f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -362,22 +362,7 @@ def _prompt_to_llm_inputs( cache_salt=cache_salt, ) - if parsed["type"] == "text": - text_content = parsed["content"] - - prompt_text = text_content["prompt"] - multi_modal_data = text_content.get("multi_modal_data") - mm_processor_kwargs = text_content.get("mm_processor_kwargs") - - if multi_modal_data is not None: - return self._process_multimodal( - prompt_text, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - return_mm_hashes=return_mm_hashes, - ) - + if prompt_token_ids is None: prompt_token_ids = self._tokenize_prompt( prompt_text, lora_request=lora_request, From 5896ba0740bde4bc2a7d3ac3f016f1734c65f122 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 09:01:31 +0000 Subject: [PATCH 05/18] refactor prompt input processing Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 44 ++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 11ac18d8a33f..6a937a2604e2 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -17,7 +17,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) -from .parse import (ParsedTextPrompt, ParsedTokensPrompt, +from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, is_explicit_encoder_decoder_prompt, parse_singleton_prompt) logger = init_logger(__name__) @@ -292,22 +292,28 @@ async def _process_multimodal_async( return_mm_hashes, cache_salt=cache_salt) - def _get_prompt_data(self, parsed_prompt: Union[ParsedTextPrompt, + def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, + ParsedTextPrompt, ParsedTokensPrompt]): prompt_text = None prompt_token_ids = None token_type_ids = None + cache_salt = None - if parsed_prompt["type"] == "tokens": - content = parsed_prompt["content"] - prompt_token_ids = content.get("prompt_token_ids") - token_type_ids = content.get("token_type_ids") - elif parsed_prompt["type"] == "text": - prompt_text = parsed_prompt["content"]["prompt"] + if parsed_prompt["type"] == "str": + prompt_text = parsed_prompt["content"] else: - assert_never(parsed_prompt) + content = parsed_prompt["content"] + cache_salt = content.get("cache_salt") + if parsed_prompt["type"] == "text": + prompt_text = content["prompt"] + elif parsed_prompt["type"] == "tokens": + prompt_token_ids = content.get("prompt_token_ids") + token_type_ids = content.get("token_type_ids") + else: + assert_never(parsed_prompt) - return prompt_text, prompt_token_ids, token_type_ids + return prompt_text, prompt_token_ids, token_type_ids, cache_salt def _prompt_to_llm_inputs( self, @@ -354,9 +360,9 @@ def _prompt_to_llm_inputs( if multi_modal_data is not None: return self._process_multimodal( - prompt_token_ids if prompt_text is None else 
prompt_text, - multi_modal_data, - mm_processor_kwargs, + prompt_text if prompt_text is not None else prompt_token_ids, + parsed["content"].get("multi_modal_data"), + parsed["content"].get("mm_processor_kwargs"), lora_request=lora_request, return_mm_hashes=return_mm_hashes, cache_salt=cache_salt, @@ -402,16 +408,12 @@ async def _prompt_to_llm_inputs_async( prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( parsed) - content = parsed["content"] - multi_modal_data = content.get("multi_modal_data") - mm_processor_kwargs = content.get("mm_processor_kwargs") - cache_salt = content.get("cache_salt") - - if multi_modal_data is not None: + if parsed["type"] != "str" and "multi_modal_data" in parsed[ + "content"] is not None: return await self._process_multimodal_async( prompt_token_ids if prompt_text is None else prompt_text, - multi_modal_data, - mm_processor_kwargs, + parsed["content"].get("multi_modal_data"), + parsed["content"].get("mm_processor_kwargs"), lora_request=lora_request, return_mm_hashes=return_mm_hashes, cache_salt=cache_salt, From 42e457ff52ada2e229da542ca2cea3aa393557c8 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:03:04 +0000 Subject: [PATCH 06/18] don't pass salt into multi-modal processor Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 28 +++++++++++++--------------- vllm/multimodal/processing.py | 1 - 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 6a937a2604e2..d176b31480f9 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -256,11 +256,10 @@ def _process_multimodal( if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, - mm_data, - mm_processor_kwargs, - return_mm_hashes, - cache_salt=cache_salt) + inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, + return_mm_hashes) + inputs["cache_salt"] = cache_salt + return inputs async def _process_multimodal_async( self, @@ -286,11 +285,10 @@ async def _process_multimodal_async( if mm_processor_kwargs is None: mm_processor_kwargs = {} - return mm_processor.apply(prompt, - mm_data, - mm_processor_kwargs, - return_mm_hashes, - cache_salt=cache_salt) + inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, + return_mm_hashes) + inputs["cache_salt"] = cache_salt + return inputs def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, ParsedTextPrompt, @@ -303,13 +301,13 @@ def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, if parsed_prompt["type"] == "str": prompt_text = parsed_prompt["content"] else: - content = parsed_prompt["content"] - cache_salt = content.get("cache_salt") + cache_salt = parsed_prompt["content"].get("cache_salt") if parsed_prompt["type"] == "text": - prompt_text = content["prompt"] + prompt_text = parsed_prompt["content"]["prompt"] elif parsed_prompt["type"] == "tokens": - prompt_token_ids = content.get("prompt_token_ids") - token_type_ids = content.get("token_type_ids") + prompt_token_ids = parsed_prompt["content"].get( + "prompt_token_ids") + token_type_ids = parsed_prompt["content"].get("token_type_ids") else: assert_never(parsed_prompt) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 077585947c98..e8745a8f1f90 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1777,7 +1777,6 @@ def apply( mm_kwargs=mm_kwargs, 
mm_hashes=mm_hashes, mm_placeholders=mm_placeholder_ranges, - cache_salt=cache_salt, ) From e8e370edceec861bd18474f4f5af86045d0f91a9 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:11:33 +0000 Subject: [PATCH 07/18] test for valid inputs Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index d176b31480f9..5cc7008f0e21 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -258,7 +258,8 @@ def _process_multimodal( inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return_mm_hashes) - inputs["cache_salt"] = cache_salt + if inputs is not None: + inputs["cache_salt"] = cache_salt return inputs async def _process_multimodal_async( @@ -287,7 +288,8 @@ async def _process_multimodal_async( inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return_mm_hashes) - inputs["cache_salt"] = cache_salt + if inputs is not None: + inputs["cache_salt"] = cache_salt return inputs def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, From 2e55fb17423fc60025b678213e46b0879dbd5541 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:30:09 +0000 Subject: [PATCH 08/18] make cache salt not required in MM inputs Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/multimodal/inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 6c82c9a802bc..978fb4231939 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -826,7 +826,7 @@ class MultiModalInputs(TypedDict): :code:`prompt_token_ids`. """ - cache_salt: Optional[str] + cache_salt: NotRequired[str] """ Optional cache salt to be used for prefix caching. 
""" From 8c2f2de0d166ed85b3a081b1b5c101ea2fb5c430 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:02:44 +0000 Subject: [PATCH 09/18] fix type errors Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 5cc7008f0e21..e2d5f12aac70 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -258,7 +258,7 @@ def _process_multimodal( inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return_mm_hashes) - if inputs is not None: + if cache_salt is not None: inputs["cache_salt"] = cache_salt return inputs @@ -288,7 +288,7 @@ async def _process_multimodal_async( inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return_mm_hashes) - if inputs is not None: + if cache_salt is not None: inputs["cache_salt"] = cache_salt return inputs @@ -361,7 +361,7 @@ def _prompt_to_llm_inputs( if multi_modal_data is not None: return self._process_multimodal( prompt_text if prompt_text is not None else prompt_token_ids, - parsed["content"].get("multi_modal_data"), + parsed["content"]["multi_modal_data"], parsed["content"].get("mm_processor_kwargs"), lora_request=lora_request, return_mm_hashes=return_mm_hashes, @@ -412,7 +412,7 @@ async def _prompt_to_llm_inputs_async( "content"] is not None: return await self._process_multimodal_async( prompt_token_ids if prompt_text is None else prompt_text, - parsed["content"].get("multi_modal_data"), + parsed["content"]["multi_modal_data"], parsed["content"].get("mm_processor_kwargs"), lora_request=lora_request, return_mm_hashes=return_mm_hashes, From dc16084e77f2273f0b5c506a79e1a4827fafdc5f Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 14:56:03 +0000 Subject: [PATCH 10/18] fix: access cache_salt if not available Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/entrypoints/openai/serving_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 477a8f79ac93..6123811aabe1 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -470,7 +470,7 @@ async def _preprocess_chat( if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs - if request.cache_salt is not None: + if hasattr(request, "cache_salt") and request.cache_salt is not None: engine_prompt["cache_salt"] = request.cache_salt return conversation, [request_prompt], [engine_prompt] From 3a5b96d93e656bad469033ebd97867c6ca668ee5 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 17:28:00 +0000 Subject: [PATCH 11/18] fix check for multi_modal_data Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index e2d5f12aac70..8fa910da3c0e 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -336,29 +336,12 @@ def _prompt_to_llm_inputs( * :class:`SingletonInputs` instance """ parsed = parse_singleton_prompt(prompt) + prompt_text, prompt_token_ids, 
token_type_ids, cache_salt = \ + self._get_prompt_data(parsed) - if parsed["type"] == "str": - prompt_text = parsed["content"] - prompt_token_ids = self._tokenize_prompt( - prompt_text, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) - - prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( - parsed) - - content = parsed["content"] - multi_modal_data = content.get("multi_modal_data") - mm_processor_kwargs = content.get("mm_processor_kwargs") - cache_salt = content.get("cache_salt") - - if multi_modal_data is not None: + # If multimodal data is present, process and return immediately + if parsed["type"] != "str" and parsed["content"].get( + "multi_modal_data") is not None: return self._process_multimodal( prompt_text if prompt_text is not None else prompt_token_ids, parsed["content"]["multi_modal_data"], @@ -408,8 +391,8 @@ async def _prompt_to_llm_inputs_async( prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( parsed) - if parsed["type"] != "str" and "multi_modal_data" in parsed[ - "content"] is not None: + if parsed["type"] != "str" and parsed["content"].get( + "multi_modal_data") is not None: return await self._process_multimodal_async( prompt_token_ids if prompt_text is None else prompt_text, parsed["content"]["multi_modal_data"], From d7cfc0a54c67e14d8574b577e10e4ab9e9a130ce Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 22:16:57 +0000 Subject: [PATCH 12/18] fix: get cache_salt from decoder inputs Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/v1/engine/processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 143bc8285ebb..27d70a781471 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -317,7 +317,7 @@ def process_inputs( eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, - cache_salt=processed_inputs.get("cache_salt"), + cache_salt=decoder_inputs.get("cache_salt"), ) def _validate_model_inputs(self, From 5680139e53a7be92eedd5ae9c98e9eef0ad7f2e9 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Mon, 28 Apr 2025 22:39:28 +0000 Subject: [PATCH 13/18] review comment Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/entrypoints/openai/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d514eeff09e5..389557dfb7c3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -744,8 +744,8 @@ def check_cache_salt_support(cls, data): raise ValueError( "Parameter 'cache_salt' is not supported with " "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], str) or len( - data["cache_salt"]) == 0: + if not isinstance(data["cache_salt"], + str) or not data["cache_salt"]: raise ValueError("Parameter 'cache_salt' must be a " "non-empty string if provided.") return data From 6a50d146eee702959cde84149755f0d69873e5eb Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Tue, 29 Apr 2025 07:46:11 +0000 Subject: [PATCH 14/18] fix cache_salt access in V0 Signed-off-by: Marko Rosenmueller 
<5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 8fa910da3c0e..8b3e1bf62e21 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -483,7 +483,6 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( mm_kwargs=inputs["mm_kwargs"], mm_hashes=inputs["mm_hashes"], mm_placeholders=inputs["mm_placeholders"], - cache_salt=inputs["cache_salt"], ) else: decoder_inputs = MultiModalInputs( @@ -493,8 +492,12 @@ def _separate_enc_dec_inputs_from_mm_processor_outputs( mm_kwargs=inputs["mm_kwargs"], mm_hashes=inputs["mm_hashes"], mm_placeholders=inputs["mm_placeholders"], - cache_salt=inputs["cache_salt"], ) + + cache_salt = inputs.get("cache_salt") + if cache_salt is not None: + decoder_inputs["cache_salt"] = cache_salt + elif inputs["type"] == "token": # Text-only inputs encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) From f53963b1810a916f486932d3a0aa2fee7922bb68 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Tue, 29 Apr 2025 15:05:48 +0000 Subject: [PATCH 15/18] fix: missing await Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 8b3e1bf62e21..e2553cf0bb76 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -403,7 +403,7 @@ async def _prompt_to_llm_inputs_async( ) if prompt_token_ids is None: - prompt_token_ids = self._tokenize_prompt_async( + prompt_token_ids = await self._tokenize_prompt_async( prompt_text, lora_request=lora_request, ) From ef0ab820336025bcf4822b5167c288715c50c947 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Tue, 29 Apr 2025 20:53:22 +0000 Subject: [PATCH 16/18] fix tests Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- tests/tokenization/test_detokenize.py | 12 ++++++++++-- tests/v1/engine/test_engine_core.py | 1 + tests/v1/engine/test_engine_core_client.py | 1 + tests/v1/engine/test_output_processor.py | 7 ++++++- 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index f8e213b9ca48..079100e78b5f 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -60,8 +60,16 @@ def _run_incremental_decode(tokenizer, skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) - request = EngineCoreRequest("", prompt_token_ids, None, None, None, params, - None, 0.0, None) + request = EngineCoreRequest("", + prompt_token_ids, + None, + None, + None, + params, + None, + 0.0, + None, + cache_salt=None) if fast is None: detokenizer = IncrementalDetokenizer.from_new_request( diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 30fa9e371ad1..dcf494825b0d 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -40,6 +40,7 @@ def make_request() -> EngineCoreRequest: eos_token_id=None, arrival_time=time.time(), lora_request=None, + cache_salt=None, ) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 8cc36fa163f7..5514a328497f 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ 
b/tests/v1/engine/test_engine_core_client.py @@ -43,6 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest: eos_token_id=None, arrival_time=time.time(), lora_request=None, + cache_salt=None, ) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index d2bb7d88fef2..fac701c4ca35 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -57,6 +57,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, mm_placeholders=None, eos_token_id=None, lora_request=None, + cache_salt=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, @@ -403,6 +404,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, mm_placeholders=None, eos_token_id=None, lora_request=None, + cache_salt=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, @@ -503,7 +505,7 @@ def test_stop_token(include_stop_str_in_output: bool, reason should be "stop" (i.e. first control token causes stop and is represented in output text) - * else, the detokenized string should be + * else, the detokenized string should be ... and the finish reason should be "stop" (i.e. first control token causes stop but is not represented in output text.) @@ -565,6 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool, mm_placeholders=None, eos_token_id=eos_token_id, lora_request=None, + cache_salt=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, @@ -661,6 +664,7 @@ def test_stop_string(include_stop_str_in_output: bool, mm_placeholders=None, eos_token_id=None, lora_request=None, + cache_salt=None, sampling_params=SamplingParams( skip_special_tokens=False, spaces_between_special_tokens=False, @@ -774,6 +778,7 @@ def test_iteration_stats(dummy_test_vectors): mm_placeholders=None, eos_token_id=None, lora_request=None, + cache_salt=None, sampling_params=SamplingParams(), ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] From af6bfccfe9a90333d892e5493b28253f258b1382 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Wed, 30 Apr 2025 06:20:37 +0000 Subject: [PATCH 17/18] merge conflict Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index e2553cf0bb76..21950a25aaff 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -375,21 +375,8 @@ async def _prompt_to_llm_inputs_async( """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) - if parsed["type"] == "str": - prompt_text = parsed["content"] - prompt_token_ids = await self._tokenize_prompt_async( - prompt=prompt_text, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - - return token_inputs( - prompt=prompt_text, - prompt_token_ids=prompt_token_ids, - ) - - prompt_text, prompt_token_ids, token_type_ids = self._get_prompt_data( - parsed) + prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ + self._get_prompt_data(parsed) if parsed["type"] != "str" and parsed["content"].get( "multi_modal_data") is not None: @@ -406,6 +393,7 @@ async def _prompt_to_llm_inputs_async( prompt_token_ids = await self._tokenize_prompt_async( prompt_text, 
lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, ) return token_inputs( From 18a9c986e98482202e631f77e475ffe65acac834 Mon Sep 17 00:00:00 2001 From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Date: Wed, 30 Apr 2025 07:33:31 +0000 Subject: [PATCH 18/18] apply cache salt after processing multimodal inputs Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> --- vllm/inputs/preprocess.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 21950a25aaff..83e6907f8c49 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -236,7 +236,6 @@ def _process_multimodal( mm_processor_kwargs: Optional[Mapping[str, object]], lora_request: Optional[LoRARequest], return_mm_hashes: bool = False, - cache_salt: Optional[str] = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -256,11 +255,8 @@ def _process_multimodal( if mm_processor_kwargs is None: mm_processor_kwargs = {} - inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - return inputs + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, + return_mm_hashes) async def _process_multimodal_async( self, @@ -269,7 +265,6 @@ async def _process_multimodal_async( mm_processor_kwargs: Optional[Mapping[str, object]], lora_request: Optional[LoRARequest], return_mm_hashes: bool = False, - cache_salt: Optional[str] = None, ) -> MultiModalInputs: """Async version of :meth:`_process_multimodal`.""" # At the moment on model (PrithviGeoSpatialMAE) requires to be @@ -286,11 +281,8 @@ async def _process_multimodal_async( if mm_processor_kwargs is None: mm_processor_kwargs = {} - inputs = mm_processor.apply(prompt, mm_data, mm_processor_kwargs, - return_mm_hashes) - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - return inputs + return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, + return_mm_hashes) def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt, ParsedTextPrompt, @@ -342,14 +334,16 @@ def _prompt_to_llm_inputs( # If multimodal data is present, process and return immediately if parsed["type"] != "str" and parsed["content"].get( "multi_modal_data") is not None: - return self._process_multimodal( + inputs = self._process_multimodal( prompt_text if prompt_text is not None else prompt_token_ids, parsed["content"]["multi_modal_data"], parsed["content"].get("mm_processor_kwargs"), lora_request=lora_request, return_mm_hashes=return_mm_hashes, - cache_salt=cache_salt, ) + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + return inputs if prompt_token_ids is None: prompt_token_ids = self._tokenize_prompt( @@ -380,14 +374,16 @@ async def _prompt_to_llm_inputs_async( if parsed["type"] != "str" and parsed["content"].get( "multi_modal_data") is not None: - return await self._process_multimodal_async( + inputs = await self._process_multimodal_async( prompt_token_ids if prompt_text is None else prompt_text, parsed["content"]["multi_modal_data"], parsed["content"].get("mm_processor_kwargs"), lora_request=lora_request, return_mm_hashes=return_mm_hashes, - cache_salt=cache_salt, ) + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + return inputs if prompt_token_ids is None: prompt_token_ids = await self._tokenize_prompt_async(