6 changes: 6 additions & 0 deletions vllm/inputs/preprocess.py
@@ -379,6 +379,7 @@ def _prompt_to_llm_inputs(
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)

prompt_token_ids = self._tokenize_prompt(
@@ -401,6 +402,7 @@ async def _prompt_to_llm_inputs_async(
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt)
@@ -431,6 +433,7 @@ async def _prompt_to_llm_inputs_async(
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)

return token_inputs(
@@ -452,6 +455,7 @@ async def _prompt_to_llm_inputs_async(
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)

prompt_token_ids = await self._tokenize_prompt_async(
@@ -726,6 +730,7 @@ def _process_decoder_only_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)

return self._build_decoder_only_llm_inputs(
@@ -746,6 +751,7 @@ async def _process_decoder_only_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)

return self._build_decoder_only_llm_inputs(
2 changes: 1 addition & 1 deletion vllm/v1/engine/__init__.py
@@ -52,7 +52,7 @@ class EngineCoreRequest(
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: list[int]
mm_inputs: Optional[list[Optional[MultiModalKwargs]]]
mm_inputs: Optional[list[MultiModalKwargs]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
122 changes: 10 additions & 112 deletions vllm/v1/engine/mm_input_cache.py
@@ -1,131 +1,30 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

from vllm.config import ModelConfig
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache

logger = init_logger(__name__)

# The idea of multimodal preprocessing caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client:
# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
# - Perform caching of the generated MultiModalKwargs.
# - This client can be deprecated once all mutimodal models migrate to use
# merged preprocessor with built-in caching functionality.
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
#
# -- Server:
# - Perform caching of the received MultiModalKwargs.
# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
#
# The caching for both client and server is mirrored/similar, and this allows us
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# client (=P0) and server (=P1) processes.
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
# cache.

# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.


# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
# merged preprocessor with built-in caching functionality.
class MMInputCacheClient:

def __init__(
self,
model_config: ModelConfig,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry.create_input_mapper(
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)

# Init cache
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)

# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
self.mm_debug_cache_hits = 0
self.mm_debug_cache_total = 0

def cache_hit_ratio(self, steps):
total = self.mm_debug_cache_total

if total > 0 and total % steps == 0:
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_debug_cache_hits / total)

# NOTE: process_inputs only supports image inputs since all multimodal
# models with other modalities have migrated to use merged preprocessor.
def process_inputs(
self,
mm_data: MultiModalDataDict,
mm_hashes: Optional[list[str]],
mm_processor_kwargs: Optional[dict[str, Any]],
precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
) -> list[Optional[MultiModalKwargs]]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
image_inputs = [image_inputs]
num_inputs = len(image_inputs)
else:
num_inputs = len(precomputed_mm_inputs)

# Sanity
if self.use_cache:
assert mm_hashes is not None
assert num_inputs == len(mm_hashes)

# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_inputs: list[Optional[MultiModalKwargs]] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)

mm_input = None
if self.use_cache:
assert mm_hashes is not None
mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash)

self.mm_debug_cache_total += 1
if mm_input is None:
if precomputed_mm_inputs is not None:
# Reuse precomputed input (for merged preprocessor)
mm_input = precomputed_mm_inputs[input_id]
else:
# Apply legacy input_mapper
mm_input = self.multi_modal_input_mapper(
{"image": [image_inputs[input_id]]},
mm_processor_kwargs=mm_processor_kwargs,
)

if self.use_cache:
# Add to cache
assert mm_hash is not None
self.mm_cache[mm_hash] = mm_input
else:
self.mm_debug_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server

ret_inputs.append(mm_input)

return ret_inputs


class MMInputCacheServer:

def __init__(self, model_config):
@@ -135,9 +34,9 @@ def __init__(self, model_config):

def get_and_update(
self,
mm_inputs: list[Optional[MultiModalKwargs]],
mm_inputs: list[MultiModalKwargs],
mm_hashes: list[str],
) -> list[Optional[MultiModalKwargs]]:
) -> list[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)

if not self.use_cache:
@@ -147,8 +46,7 @@ def get_and_update(
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
if mm_input is None:
mm_input = self.mm_cache.get(mm_hash)
assert mm_input is not None
mm_input = self.mm_cache[mm_hash]
else:
self.mm_cache[mm_hash] = mm_input

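Below is a minimal, self-contained sketch (not part of the diff) of the mirrored client/server caching idea described in the comments of mm_input_cache.py above. Plain dicts stand in for the LRU ProcessingCache, and the function names client_prepare / server_get_and_update are hypothetical stand-ins for the P0 frontend logic and MMInputCacheServer.get_and_update; they only illustrate why a cache hit lets P0 send None instead of serializing pixel values.

from typing import Optional

# Toy stand-ins for the mirrored LRU caches on P0 (client) and P1 (server).
# In vLLM both sides size their cache from VLLM_MM_INPUT_CACHE_GIB so that
# hits and evictions stay in sync.
client_cache: dict[str, dict] = {}
server_cache: dict[str, dict] = {}

def client_prepare(mm_inputs: list[dict],
                   mm_hashes: list[str]) -> list[Optional[dict]]:
    """On a cache hit, send None instead of the (large) kwargs."""
    sent: list[Optional[dict]] = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_hash in client_cache:
            sent.append(None)
        else:
            client_cache[mm_hash] = mm_input
            sent.append(mm_input)
    return sent

def server_get_and_update(mm_inputs: list[Optional[dict]],
                          mm_hashes: list[str]) -> list[dict]:
    """Fill None entries from the mirrored cache, store new ones."""
    full: list[dict] = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_input is None:
            full.append(server_cache[mm_hash])
        else:
            server_cache[mm_hash] = mm_input
            full.append(mm_input)
    return full

# First request ships the kwargs; a repeat of the same hash ships only None.
inputs, hashes = [{"pixel_values": "big-tensor"}], ["hash-a"]
assert server_get_and_update(client_prepare(inputs, hashes), hashes) == inputs
assert client_prepare(inputs, hashes) == [None]
assert server_get_and_update([None], hashes) == inputs

Because the two caches are kept in lockstep, a None on the wire should always be resolvable on P1, which is why the diff lets get_and_update index the cache directly (self.mm_cache[mm_hash]) instead of calling get() and asserting.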
80 changes: 23 additions & 57 deletions vllm/v1/engine/processor.py
@@ -11,15 +11,15 @@
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
MultiModalKwargs, MultiModalRegistry)
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.structured_output.utils import validate_structured_output_request


@@ -45,11 +45,6 @@ def __init__(
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_processor = input_registry.create_input_processor(
self.model_config)

# Multi-modal (huggingface) input mapper
self.mm_input_cache_client = MMInputCacheClient(self.model_config)

# Multi-modal hasher (for images)
self.use_hash = (
@@ -171,7 +166,7 @@ def process_inputs(
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
# 3. Apply prompt adapter to prompt token ids if one exists.
preprocessed_inputs = self.input_preprocessor.preprocess(
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
lora_request=lora_request,
@@ -180,10 +175,6 @@
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

# Process prompt and prompt token ids.
# Only applicable to multimodal models with legacy input processor.
processed_inputs = self.input_processor(preprocessed_inputs)

self._validate_model_inputs(processed_inputs, lora_request)

if is_encoder_decoder_inputs(processed_inputs):
@@ -212,36 +203,22 @@ def process_inputs(
self.tokenizer.get_lora_tokenizer(lora_request))

# Multimodal related.
# Compute MM hashes (if enabled)
mm_hashes = None
if self.use_hash:
# Use mm_hashes from processed inputs if the model has merged
# input processor.
if decoder_inputs.multi_modal_hashes:
mm_hashes = decoder_inputs.multi_modal_hashes
# Fallback to using MultiModalHasher directly.
else:
mm_hashes = MultiModalHasher.hash_prompt_mm_data(prompt)
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if (decoder_mm_inputs := decoder_inputs.multi_modal_data):
assert isinstance(decoder_mm_inputs, MultiModalKwargs)

# For merged preprocessor, mm_data is already mm_inputs
precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None
decoder_mm_data = decoder_inputs.multi_modal_data
if isinstance(decoder_mm_data, MultiModalKwargs):
# The output of merged multi-modal processor (`decoder_mm_data`)
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
precomputed_mm_inputs = [
individual_mm_inputs = [
MultiModalKwargs.from_items([item])
for modality in decoder_mm_data.modalities
for item in decoder_mm_data.get_items(modality)
for modality in decoder_mm_inputs.modalities
for item in decoder_mm_inputs.get_items(modality)
]

mm_positions = decoder_inputs.multi_modal_placeholders

# Last-mile processing of multimodal metadata and inputs.
if mm_positions:

# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
@@ -251,41 +228,30 @@ def process_inputs(
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
mm_positions,
mm_hashes,
decoder_inputs.multi_modal_placeholders,
decoder_inputs.multi_modal_hashes if self.use_hash else None,
)

# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# modalities involved AND the model supports merged input processor.
if len(sorted_modalities) > 1 and precomputed_mm_inputs:

# modalities involved.
if len(sorted_modalities) > 1:
modality_order_dict = {
modality: order
for order, modality in enumerate(sorted_modalities)
}

# Sanity check to make sure each multimodal input has only one
# modality key.
for mm_input in precomputed_mm_inputs:
for mm_input in individual_mm_inputs:
assert len(mm_input.modalities) == 1

# Sort MultiModalKwags to match sorted_mm_positions
precomputed_mm_inputs = sorted(
precomputed_mm_inputs,
# Sort MultiModalKwargs to match sorted_mm_positions
sorted_mm_inputs = sorted(
individual_mm_inputs,
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])

# Apply mm input cache update and legacy input mapper if one exists.
sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
mm_data=decoder_mm_data,
mm_hashes=sorted_mm_hashes,
mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
precomputed_mm_inputs=precomputed_mm_inputs,
)
else:
sorted_mm_inputs = None
sorted_mm_hashes = None
sorted_mm_positions = None
else:
sorted_mm_inputs = individual_mm_inputs

return EngineCoreRequest(
request_id=request_id,
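For the reordering in process_inputs above, here is a small, self-contained illustration (not part of the diff) of how per-item inputs end up aligned with placeholders that were merged across modalities and sorted by position. Dicts and tuples stand in for PlaceholderRange and MultiModalKwargs, the merging only mirrors what merge_and_sort_multimodal_metadata is expected to return, and the toy prompt keeps each modality's items contiguous.

# Per-modality placeholders, as in decoder_inputs.multi_modal_placeholders.
placeholders = {
    "image": [{"offset": 12, "length": 4}, {"offset": 30, "length": 4}],
    "audio": [{"offset": 3, "length": 6}],
}
# One opaque input per item, grouped modality-by-modality, the way
# MultiModalKwargs.from_items([item]) is applied in the loop above.
individual_inputs = [("image", "img-0"), ("image", "img-1"), ("audio", "aud-0")]

# Merge placeholders across modalities and sort them by prompt position
# (roughly what merge_and_sort_multimodal_metadata produces).
merged = sorted(
    ((modality, p) for modality, ps in placeholders.items() for p in ps),
    key=lambda mp: mp[1]["offset"],
)
sorted_mm_positions = [p for _, p in merged]
sorted_modalities = list(dict.fromkeys(m for m, _ in merged))  # ["audio", "image"]

# Stable-sort the per-item inputs by their modality's rank so that the i-th
# input lines up with the i-th sorted placeholder.
rank = {m: i for i, m in enumerate(sorted_modalities)}
sorted_mm_inputs = [x for _, x in sorted(individual_inputs,
                                         key=lambda mi: rank[mi[0]])]

assert [p["offset"] for p in sorted_mm_positions] == [3, 12, 30]
assert sorted_mm_inputs == ["aud-0", "img-0", "img-1"]

With a single modality this sorting step is skipped and the per-item list is used as-is, matching the else branch in the diff.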