@@ -36,7 +36,8 @@
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
+                                    PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
@@ -51,7 +52,6 @@
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches,
     reorder_batch_to_split_decodes_and_prefills)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
@@ -73,7 +73,7 @@
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (bind_kv_cache, gather_mm_placeholders,
+from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
                     initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
@@ -148,14 +148,6 @@ def __init__(
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
 
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
 
@@ -330,6 +322,14 @@ def __init__(
         self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=self.device)
 
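+        # Encoder compute/cache budget for multimodal inputs;
+        # None for text-only models.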
+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)
+
         self.reorder_batch_threshold: Optional[int] = None
 
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
@@ -578,37 +578,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
-    def _init_model_kwargs_for_multimodal_model(
+    def _extract_mm_kwargs(
         self,
-        scheduler_output: Optional["SchedulerOutput"] = None,
-        num_reqs: int = -1,
-    ) -> dict[str, Any]:
-
-        model_kwargs: dict[str, Any] = {}
-        if self.is_multimodal_raw_input_supported:
-            # This model requires the raw multimodal data in input.
+        scheduler_output: "SchedulerOutput",
+    ) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:  # noqa: SIM102
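+            # This model consumes raw multimodal data directly as input.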
             if scheduler_output:
-                multi_modal_kwargs_list = []
+                multi_modal_kwargs_list = list[MultiModalKwargs]()
                 for req in scheduler_output.scheduled_new_reqs:
                     req_mm_inputs = req.mm_inputs
                     if not isinstance(req_mm_inputs, list):
                         req_mm_inputs = list(req_mm_inputs)
                     multi_modal_kwargs_list.extend(req_mm_inputs)
-                multi_modal_kwargs = MultiModalKwargs.batch(
-                    multi_modal_kwargs_list)
-            else:
-                # The only case where SchedulerOutput is None is for
-                # a dummy run let's get some dummy data.
-                dummy_data = [
-                    self.mm_registry.get_decoder_dummy_data(
-                        model_config=self.model_config,
-                        seq_len=1).multi_modal_data for i in range(num_reqs)
-                ]
-                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
 
-            model_kwargs.update(multi_modal_kwargs)
+            return MultiModalKwargs.batch(multi_modal_kwargs_list)
 
-        return model_kwargs
+        return {}
+
+    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
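+            # Profile with the modality whose items consume the most
+            # encoder tokens.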
+            dummy_modality, _ = mm_budget.get_modality_with_max_tokens()
+
+            return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+
+        return {}
 
     def _get_cumsum_and_arange(
         self,
@@ -1517,27 +1513,26 @@ def execute_model(
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
-
-            model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                scheduler_output=scheduler_output)
-            inputs_embeds = self.model.get_input_embeddings(
-                input_ids=input_ids,
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids[:num_scheduled_tokens],
                 multimodal_embeddings=mm_embeds or None,
             )
 
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)
+
             input_ids = None
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
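+            # Raw multimodal kwargs ({} unless the model takes raw mm input).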
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
-            model_kwargs = {}
+            model_mm_kwargs = {}
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
         else:
@@ -1571,7 +1566,7 @@ def execute_model(
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                     device=self.device,
                 ),
             )
@@ -2149,6 +2144,30 @@ def rand_input_ids() -> torch.Tensor:
             yield
             input_ids.fill_(0)
 
+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
+            model_config=self.model_config,
+            seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
+        )
+        dummy_mm_data = dummy_decoder_data.multi_modal_data
+
+        # Use the item that results in the maximum GPU memory consumption.
+        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
+
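+        # Replicate the single largest item to fill the profiling batch.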
+        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
+                                                         max_items_per_batch)
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            device=self.device,
+        )
+
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -2213,16 +2232,14 @@ def _dummy_run(
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
-                model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                    num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
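+                # Dummy multimodal kwargs ({} unless the model takes
+                # raw mm input).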
+                model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
             else:
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None
-                model_kwargs = {}
+                model_mm_kwargs = {}
 
             if self.uses_mrope:
                 positions = self.mrope_positions[:, :num_tokens]
@@ -2247,13 +2264,13 @@ def _dummy_run(
                     self.vllm_config,
                     num_tokens=num_tokens,
                     num_tokens_across_dp=num_tokens_across_dp):
-                outputs = model(
+                outputs = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                     **MultiModalKwargs.as_kwargs(
-                        model_kwargs,
+                        model_mm_kwargs,
                         device=self.device,
                     ),
                 )
@@ -2423,75 +2440,51 @@ def _dummy_pooler_run(
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-
-            # NOTE: Currently model is profiled with a single non-text
-            # modality with the max possible input tokens even when
-            # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry \
-                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
-            dummy_data_modality, max_tokens_per_mm_item = max(
-                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-            # Check how many items of this modality can be supported by
-            # the encoder budget.
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
-
-            max_num_mm_items_encoder_budget = encoder_budget // \
-                max_tokens_per_mm_item
-
-            # Check how many items of this modality can be supported by
-            # the decoder budget.
-            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-                self.model_config)[dummy_data_modality]
-
-            # NOTE: We do not consider max_num_batched_tokens on purpose
-            # because the multimodal embeddings can be generated in advance
-            # and chunked prefilled.
-            max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                max_mm_items_per_req
-
-            max_num_mm_items = max(
-                1,
-                min(max_num_mm_items_encoder_budget,
-                    max_num_mm_items_decoder_budget))
-
-            logger.info(
-                "Encoder cache will be initialized with a budget of %s tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
-
-            # Create dummy batch of multimodal inputs.
-            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
-                model_config=self.model_config,
-                seq_len=max_tokens_per_mm_item,
-                mm_counts={
-                    dummy_data_modality: 1
-                },
-            ).multi_modal_data
-
-            batched_dummy_mm_inputs = MultiModalKwargs.batch(
-                [dummy_mm_kwargs] * max_num_mm_items,
-                pin_memory=self.pin_memory)
-            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-                batched_dummy_mm_inputs,
-                device=self.device,
-            )
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                # NOTE: Currently model is profiled with a single non-text
+                # modality with the max possible input tokens even when
+                # it supports multiple.
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                logger.info(
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )
 
-            # Run multimodal encoder.
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                **batched_dummy_mm_inputs)
+                # Create dummy batch of multimodal inputs.
+                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )
 
-            sanity_check_mm_encoder_outputs(
-                dummy_encoder_outputs,
-                expected_num_items=max_num_mm_items,
-            )
+                # Run multimodal encoder.
+                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
+
+                sanity_check_mm_encoder_outputs(
+                    dummy_encoder_outputs,
+                    expected_num_items=max_mm_items_per_batch,
+                )
 
-            # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+                # Cache the dummy encoder outputs.
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))
 
         # Add `is_profile` here to pre-allocate communication buffers
         hidden_states, last_hidden_states \