diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index f56cff292b68..af35e43d825a 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -379,6 +379,7 @@ def _prompt_to_llm_inputs( multi_modal_data, mm_processor_kwargs, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) prompt_token_ids = self._tokenize_prompt( @@ -401,6 +402,7 @@ async def _prompt_to_llm_inputs_async( prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, ) -> SingletonInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) @@ -431,6 +433,7 @@ async def _prompt_to_llm_inputs_async( multi_modal_data, mm_processor_kwargs, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) return token_inputs( @@ -452,6 +455,7 @@ async def _prompt_to_llm_inputs_async( multi_modal_data, mm_processor_kwargs, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) prompt_token_ids = await self._tokenize_prompt_async( @@ -726,6 +730,7 @@ def _process_decoder_only_prompt( prompt, request_id=request_id, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) return self._build_decoder_only_llm_inputs( @@ -746,6 +751,7 @@ async def _process_decoder_only_prompt_async( prompt, request_id=request_id, lora_request=lora_request, + return_mm_hashes=return_mm_hashes, ) return self._build_decoder_only_llm_inputs( diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index cd29c2d7d57c..3699779b3a0f 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -52,7 +52,7 @@ class EngineCoreRequest( # Detokenizer, but set to None when it is added to EngineCoreClient. prompt: Optional[str] prompt_token_ids: list[int] - mm_inputs: Optional[list[Optional[MultiModalKwargs]]] + mm_inputs: Optional[list[MultiModalKwargs]] mm_hashes: Optional[list[str]] mm_placeholders: Optional[list[PlaceholderRange]] sampling_params: SamplingParams diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index e2dda73ba429..61a55d2499bd 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -1,131 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional - -from vllm.config import ModelConfig from vllm.envs import VLLM_MM_INPUT_CACHE_GIB -from vllm.logger import init_logger -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalKwargs, MultiModalRegistry) +from vllm.multimodal import MultiModalKwargs from vllm.multimodal.processing import ProcessingCache -logger = init_logger(__name__) - # The idea of multimodal preprocessing caching is based on having a client and # a server, where the client executes in the frontend process (=P0) and the # server in the core process (=P1). # # -- Client: -# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs. -# - Perform caching of the generated MultiModalKwargs. -# - This client can be deprecated once all mutimodal models migrate to use -# merged preprocessor with built-in caching functionality. +# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs +# with built-in caching functionality, with mm_hash as its identifier. # # -- Server: -# - Perform caching of the received MultiModalKwargs. +# - MMInputCacheServer to perform caching of the received MultiModalKwargs. 
# -# The caching for both client and server is mirrored/similar, and this allows us +# The caching for both client and server is mirrored, and this allows us # to avoid the serialization of "mm_inputs" (like pixel values) between -# client (=P0) and server (=P1) processes. +# client (=P0) and server (=P1) processes if the mm_hash is found in the client +# cache. # Both Client and Server must use the same cache size # (to perform mirrored caching). This cache size is set by the environment # variable VLLM_MM_INPUT_CACHE_GIB. -# TODO(ywang96): Deprecate this class once all multimodal models migrate to use -# merged preprocessor with built-in caching functionality. -class MMInputCacheClient: - - def __init__( - self, - model_config: ModelConfig, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - self.model_config = model_config - self.mm_registry = mm_registry - self.multi_modal_input_mapper = mm_registry.create_input_mapper( - model_config) - self.mm_registry.init_mm_limits_per_prompt(model_config) - - # Init cache - self.use_cache = not model_config.disable_mm_preprocessor_cache - self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, - MultiModalKwargs) - - # DEBUG: Set to None to disable - self.mm_debug_cache_hit_ratio_steps = None - self.mm_debug_cache_hits = 0 - self.mm_debug_cache_total = 0 - - def cache_hit_ratio(self, steps): - total = self.mm_debug_cache_total - - if total > 0 and total % steps == 0: - logger.debug("MMInputMapper: cache_hit_ratio = %.2f ", - self.mm_debug_cache_hits / total) - - # NOTE: process_inputs only supports image inputs since all multimodal - # models with other modalities have migrated to use merged preprocessor. - def process_inputs( - self, - mm_data: MultiModalDataDict, - mm_hashes: Optional[list[str]], - mm_processor_kwargs: Optional[dict[str, Any]], - precomputed_mm_inputs: Optional[list[MultiModalKwargs]], - ) -> list[Optional[MultiModalKwargs]]: - if precomputed_mm_inputs is None: - image_inputs = mm_data["image"] - if not isinstance(image_inputs, list): - image_inputs = [image_inputs] - num_inputs = len(image_inputs) - else: - num_inputs = len(precomputed_mm_inputs) - - # Sanity - if self.use_cache: - assert mm_hashes is not None - assert num_inputs == len(mm_hashes) - - # Process each image input separately, so that later we can schedule - # them in a fine-grained manner. 
- # Apply caching (if enabled) and reuse precomputed inputs (if provided) - ret_inputs: list[Optional[MultiModalKwargs]] = [] - for input_id in range(num_inputs): - if self.mm_debug_cache_hit_ratio_steps is not None: - self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) - - mm_input = None - if self.use_cache: - assert mm_hashes is not None - mm_hash = mm_hashes[input_id] - mm_input = self.mm_cache.get(mm_hash) - - self.mm_debug_cache_total += 1 - if mm_input is None: - if precomputed_mm_inputs is not None: - # Reuse precomputed input (for merged preprocessor) - mm_input = precomputed_mm_inputs[input_id] - else: - # Apply legacy input_mapper - mm_input = self.multi_modal_input_mapper( - {"image": [image_inputs[input_id]]}, - mm_processor_kwargs=mm_processor_kwargs, - ) - - if self.use_cache: - # Add to cache - assert mm_hash is not None - self.mm_cache[mm_hash] = mm_input - else: - self.mm_debug_cache_hits += 1 - mm_input = None # Avoids sending mm_input to Server - - ret_inputs.append(mm_input) - - return ret_inputs - - class MMInputCacheServer: def __init__(self, model_config): @@ -135,9 +34,9 @@ def __init__(self, model_config): def get_and_update( self, - mm_inputs: list[Optional[MultiModalKwargs]], + mm_inputs: list[MultiModalKwargs], mm_hashes: list[str], - ) -> list[Optional[MultiModalKwargs]]: + ) -> list[MultiModalKwargs]: assert len(mm_inputs) == len(mm_hashes) if not self.use_cache: @@ -147,8 +46,7 @@ def get_and_update( for mm_input, mm_hash in zip(mm_inputs, mm_hashes): assert mm_hash is not None if mm_input is None: - mm_input = self.mm_cache.get(mm_hash) - assert mm_input is not None + mm_input = self.mm_cache[mm_hash] else: self.mm_cache[mm_hash] = mm_input diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 663e1e36f756..4e9e5506bb58 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -11,15 +11,15 @@ from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher, - MultiModalKwargs, MultiModalRegistry) +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + MultiModalRegistry) +from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import merge_and_sort_multimodal_metadata from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.structured_output.utils import validate_structured_output_request @@ -45,11 +45,6 @@ def __init__( self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer, mm_registry) - self.input_processor = input_registry.create_input_processor( - self.model_config) - - # Multi-modal (huggingface) input mapper - self.mm_input_cache_client = MMInputCacheClient(self.model_config) # Multi-modal hasher (for images) self.use_hash = ( @@ -171,7 +166,7 @@ def process_inputs( # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. # 3. Apply prompt adapter to prompt token ids if one exists. 
- preprocessed_inputs = self.input_preprocessor.preprocess( + processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, request_id=request_id, lora_request=lora_request, @@ -180,10 +175,6 @@ def process_inputs( ) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - # Process prompt and prompt token ids. - # Only applicable to multimodal models with legacy input processor. - processed_inputs = self.input_processor(preprocessed_inputs) - self._validate_model_inputs(processed_inputs, lora_request) if is_encoder_decoder_inputs(processed_inputs): @@ -212,36 +203,22 @@ def process_inputs( self.tokenizer.get_lora_tokenizer(lora_request)) # Multimodal related. - # Compute MM hashes (if enabled) - mm_hashes = None - if self.use_hash: - # Use mm_hashes from processed inputs if the model has merged - # input processor. - if decoder_inputs.multi_modal_hashes: - mm_hashes = decoder_inputs.multi_modal_hashes - # Fallback to using MultiModalHasher directly. - else: - mm_hashes = MultiModalHasher.hash_prompt_mm_data(prompt) + sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None + sorted_mm_positions: Optional[list[PlaceholderRange]] = None + sorted_mm_hashes: Optional[list[str]] = None + if (decoder_mm_inputs := decoder_inputs.multi_modal_data): + assert isinstance(decoder_mm_inputs, MultiModalKwargs) - # For merged preprocessor, mm_data is already mm_inputs - precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None - decoder_mm_data = decoder_inputs.multi_modal_data - if isinstance(decoder_mm_data, MultiModalKwargs): - # The output of merged multi-modal processor (`decoder_mm_data`) + # The output of merged multi-modal processor (`decoder_mm_inputs`) # contains the kwargs for all items from all modalities. # This code separates them so that there is one set of kwargs # per item per modality. - precomputed_mm_inputs = [ + individual_mm_inputs = [ MultiModalKwargs.from_items([item]) - for modality in decoder_mm_data.modalities - for item in decoder_mm_data.get_items(modality) + for modality in decoder_mm_inputs.modalities + for item in decoder_mm_inputs.get_items(modality) ] - mm_positions = decoder_inputs.multi_modal_placeholders - - # Last-mile processing of multimodal metadata and inputs. - if mm_positions: - # Merge and flatten multimodal placeholders, hashes and inputs # from dictionaries to lists, and sort them by each item's position # in the input sequence. @@ -251,14 +228,13 @@ def process_inputs( sorted_mm_positions, sorted_mm_hashes, ) = merge_and_sort_multimodal_metadata( - mm_positions, - mm_hashes, + decoder_inputs.multi_modal_placeholders, + decoder_inputs.multi_modal_hashes if self.use_hash else None, ) # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple - # modalities involved AND the model supports merged input processor. - if len(sorted_modalities) > 1 and precomputed_mm_inputs: - + # modalities involved. + if len(sorted_modalities) > 1: modality_order_dict = { modality: order for order, modality in enumerate(sorted_modalities) @@ -266,26 +242,16 @@ def process_inputs( # Sanity check to make sure each multimodal input has only one # modality key. 
- for mm_input in precomputed_mm_inputs: + for mm_input in individual_mm_inputs: assert len(mm_input.modalities) == 1 - # Sort MultiModalKwags to match sorted_mm_positions - precomputed_mm_inputs = sorted( - precomputed_mm_inputs, + # Sort MultiModalKwargs to match sorted_mm_positions + sorted_mm_inputs = sorted( + individual_mm_inputs, key=lambda mm_input: modality_order_dict[list( mm_input.modalities)[0]]) - - # Apply mm input cache update and legacy input mapper if one exists. - sorted_mm_inputs = self.mm_input_cache_client.process_inputs( - mm_data=decoder_mm_data, - mm_hashes=sorted_mm_hashes, - mm_processor_kwargs=decoder_inputs.mm_processor_kwargs, - precomputed_mm_inputs=precomputed_mm_inputs, - ) - else: - sorted_mm_inputs = None - sorted_mm_hashes = None - sorted_mm_positions = None + else: + sorted_mm_inputs = individual_mm_inputs return EngineCoreRequest( request_id=request_id, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8dd7521ff49a..4c82da7e1a8b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -29,7 +29,6 @@ is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, @@ -132,14 +131,6 @@ def __init__( self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope - if self.is_multimodal_model: - # NOTE: Initialized client is only used for processing dummy - # multimodal data into multimodal kwargs for GPU memory profiling. - # Only applicable to multimodal models with legacy input mapper. - self.mm_input_mapper_profiling = MMInputCacheClient( - self.model_config) - self.mm_input_mapper_profiling.use_cache = False - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( model_config=model_config, scheduler_config=scheduler_config, @@ -1358,32 +1349,18 @@ def profile_run(self) -> None: mm_registry=self.mm_registry, ) dummy_mm_data = dummy_request_data.multi_modal_data + if not isinstance(dummy_mm_data, MultiModalKwargs): + # TODO: Delete this check once input mapper is fully removed. + raise RuntimeError( + "Legacy input mapper is not supported in V1") - # Dummy data definition in V0 may contain multiple multimodal items + # Dummy data definition may contain multiple multimodal items # (e.g, multiple images) for a single request, therefore here we # always replicate first item by max_num_mm_items times since in V1 # they are scheduled to be processed separately. - - # Case when models have a merged processor, their dummy data is - # already batched `MultiModalKwargs`, therefore we take the first - # `MultiModalKwargsItem` from the desired modality to profile on. - if isinstance(dummy_mm_data, MultiModalKwargs): - dummy_mm_item = dummy_mm_data.get_item( - modality=dummy_data_modality, item_index=0) - dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) - - # Case when models have dummy data explicitly defined as - # `MultiModalDataDict`, so they need to be processed through input - # mapper. - # TODO (ywang96): deprecate this path once merged processor is - # supported on all models. 
- else: - mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs( - mm_data=dummy_mm_data, - mm_hashes=None, - mm_processor_kwargs=None, - precomputed_mm_inputs=None) - dummy_mm_kwargs = mm_kwargs_list[0] + dummy_mm_item = dummy_mm_data.get_item( + modality=dummy_data_modality, item_index=0) + dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) batched_dummy_mm_inputs = MultiModalKwargs.batch( [dummy_mm_kwargs] * max_num_mm_items)
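
Note: the mirrored P0/P1 caching described in the rewritten comment block of `vllm/v1/engine/mm_input_cache.py` can be hard to picture from the diff alone, since the client-side half now lives in the merged multimodal processor rather than in this file. Below is a minimal, self-contained sketch of the protocol, not the vLLM implementation: the names `_LRUCache`, `CacheClientP0`, `CacheServerP1`, and `prepare_for_transfer` are hypothetical stand-ins for `ProcessingCache`/`BaseMultiModalProcessor` on P0 and `MMInputCacheServer.get_and_update` on P1, and it uses an entry-count capacity instead of the size-based limit set by `VLLM_MM_INPUT_CACHE_GIB`. The point it illustrates is the invariant the comments rely on: both sides apply the same lookups and insertions in the same order, so a client cache hit (transmitted as `None`) is expected to be resolvable from the server's mirrored cache without serializing the kwargs.

```python
# Simplified, hypothetical sketch of the mirrored frontend/core cache protocol.
from collections import OrderedDict
from typing import Any, Optional


class _LRUCache:
    """Tiny LRU cache keyed by mm_hash (stand-in for the real LRU cache)."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data: "OrderedDict[str, Any]" = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        if key in self._data:
            self._data.move_to_end(key)  # refresh recency on hit
            return self._data[key]
        return None

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict least recently used


class CacheClientP0:
    """Frontend (P0): replaces already-cached kwargs with None before sending."""

    def __init__(self, capacity: int) -> None:
        self.cache = _LRUCache(capacity)

    def prepare_for_transfer(
        self,
        mm_inputs: list[Any],
        mm_hashes: list[str],
    ) -> list[Optional[Any]]:
        out: list[Optional[Any]] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if self.cache.get(mm_hash) is not None:
                # The mirrored server cache saw the same insert earlier,
                # so the (potentially large) kwargs need not be serialized.
                out.append(None)
            else:
                self.cache.put(mm_hash, mm_input)
                out.append(mm_input)
        return out


class CacheServerP1:
    """Core (P1): mirrors the client cache and restores the None placeholders."""

    def __init__(self, capacity: int) -> None:
        self.cache = _LRUCache(capacity)

    def get_and_update(
        self,
        mm_inputs: list[Optional[Any]],
        mm_hashes: list[str],
    ) -> list[Any]:
        full: list[Any] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                # Client hit its cache; the mirrored cache should hold the entry.
                mm_input = self.cache.get(mm_hash)
                assert mm_input is not None
            else:
                self.cache.put(mm_hash, mm_input)
            full.append(mm_input)
        return full
```

Because both caches are updated with the same sequence of hashes and share the same capacity and eviction policy, the two LRU states stay in lockstep, which is why the new server-side code in the diff can index `self.mm_cache[mm_hash]` directly when it receives `None`.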
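
Similarly, the `vllm/v1/engine/processor.py` hunk depends on `merge_and_sort_multimodal_metadata` to flatten per-modality placeholders and hashes into lists ordered by each item's position in the prompt, after which the per-item `MultiModalKwargs` are reordered to match. The toy example below is an illustrative stand-in for that merge-and-sort step, assuming one hash per placeholder; the `Placeholder` dataclass and `merge_and_sort` helper are hypothetical and are not the actual vLLM utilities.

```python
# Hypothetical illustration of merging per-modality metadata and sorting by
# each item's offset in the input sequence.
from dataclasses import dataclass


@dataclass
class Placeholder:
    offset: int
    length: int


def merge_and_sort(
    placeholders_by_modality: dict[str, list[Placeholder]],
    hashes_by_modality: dict[str, list[str]],
) -> tuple[list[str], list[Placeholder], list[str]]:
    # Flatten (placeholder, modality, hash) triples across all modalities.
    entries = [
        (pos, modality, mm_hash)
        for modality, positions in placeholders_by_modality.items()
        for pos, mm_hash in zip(positions, hashes_by_modality[modality])
    ]
    # Order items by where they appear in the prompt.
    entries.sort(key=lambda e: e[0].offset)
    # Modalities listed in order of first appearance.
    modality_order = list(dict.fromkeys(modality for _, modality, _ in entries))
    return (
        modality_order,
        [pos for pos, _, _ in entries],
        [mm_hash for _, _, mm_hash in entries],
    )


# Example: an audio clip precedes the image in the prompt, so the flattened
# lists interleave the two modalities by offset.
sorted_modalities, sorted_positions, sorted_hashes = merge_and_sort(
    {"image": [Placeholder(40, 16)], "audio": [Placeholder(5, 8)]},
    {"image": ["img-0"], "audio": ["aud-0"]},
)
assert sorted_modalities == ["audio", "image"]
assert sorted_hashes == ["aud-0", "img-0"]
```

This mirrors why the new processor code sorts `individual_mm_inputs` only when more than one modality is present: with a single modality the flattened order already matches the placeholder order.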