@@ -36,7 +36,8 @@
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
+                                    PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
@@ -51,7 +52,6 @@
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches,
     reorder_batch_to_split_decodes_and_prefills)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
@@ -73,7 +73,7 @@
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (bind_kv_cache, gather_mm_placeholders,
+from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
                     initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
@@ -148,14 +148,6 @@ def __init__(
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
 
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -330,6 +322,14 @@ def __init__(
             self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
                 self.max_num_tokens, dtype=torch.int32, device=self.device)
 
+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)
+
         self.reorder_batch_threshold: Optional[int] = None
 
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
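
Note: the inline `compute_encoder_budget` bookkeeping removed above now lives behind the new `MultiModalBudget` helper imported from `.utils`. Its implementation is not part of this diff; the sketch below is a hypothetical reconstruction inferred purely from the call sites visible here (`get_encoder_budget`, `get_modality_with_max_tokens`, `get_max_items`) and from the logic removed from `profile_run()` further down, so names and details may differ from the real class.

```python
# Hypothetical sketch of the MultiModalBudget interface, inferred from the
# call sites in this diff; the real class in .utils may differ.
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget


class MultiModalBudgetSketch:

    def __init__(self, model_config, scheduler_config, mm_registry, *,
                 max_model_len: int, max_num_reqs: int) -> None:
        self.model_config = model_config
        self.mm_registry = mm_registry
        self.max_num_reqs = max_num_reqs
        # Same computation that used to be inlined in __init__ above.
        compute_budget, cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=mm_registry,
        )
        self.max_num_encoder_input_tokens = compute_budget
        self.encoder_cache_size = cache_size

    def get_encoder_budget(self) -> int:
        return min(self.max_num_encoder_input_tokens, self.encoder_cache_size)

    def get_modality_with_max_tokens(self) -> tuple[str, int]:
        # Modality whose single item consumes the most encoder tokens.
        max_tokens_by_modality = (
            self.mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        return max(max_tokens_by_modality.items(), key=lambda kv: kv[1])

    def get_max_items(self, modality: str,
                      max_tokens_per_item: int) -> tuple[int, int]:
        # (per-prompt limit, number of max-size items to profile with),
        # mirroring the arithmetic removed from profile_run() below.
        encoder_items = self.get_encoder_budget() // max_tokens_per_item
        per_prompt = self.mm_registry.get_mm_limits_per_prompt(
            self.model_config)[modality]
        decoder_items = self.max_num_reqs * per_prompt
        return per_prompt, max(1, min(encoder_items, decoder_items))
```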
@@ -578,37 +578,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
-    def _init_model_kwargs_for_multimodal_model(
+    def _extract_mm_kwargs(
         self,
-        scheduler_output: Optional["SchedulerOutput"] = None,
-        num_reqs: int = -1,
-    ) -> dict[str, Any]:
-
-        model_kwargs: dict[str, Any] = {}
-        if self.is_multimodal_raw_input_supported:
-            # This model requires the raw multimodal data in input.
+        scheduler_output: "SchedulerOutput",
+    ) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:  # noqa: SIM102
             if scheduler_output:
-                multi_modal_kwargs_list = []
+                multi_modal_kwargs_list = list[MultiModalKwargs]()
                 for req in scheduler_output.scheduled_new_reqs:
                     req_mm_inputs = req.mm_inputs
                     if not isinstance(req_mm_inputs, list):
                         req_mm_inputs = list(req_mm_inputs)
                     multi_modal_kwargs_list.extend(req_mm_inputs)
-                multi_modal_kwargs = MultiModalKwargs.batch(
-                    multi_modal_kwargs_list)
-            else:
-                # The only case where SchedulerOutput is None is for
-                # a dummy run let's get some dummy data.
-                dummy_data = [
-                    self.mm_registry.get_decoder_dummy_data(
-                        model_config=self.model_config,
-                        seq_len=1).multi_modal_data for i in range(num_reqs)
-                ]
-                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
 
-            model_kwargs.update(multi_modal_kwargs)
+                return MultiModalKwargs.batch(multi_modal_kwargs_list)
 
-        return model_kwargs
+        return {}
+
+    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            dummy_modality, _ = mm_budget.get_modality_with_max_tokens()
+
+            return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+
+        return {}
 
     def _get_cumsum_and_arange(
         self,
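
One small idiom worth flagging in the hunk above: `list[MultiModalKwargs]()` instantiates an empty list through a PEP 585 generic alias. At runtime it is an ordinary list; the type argument is documentation for readers and type checkers only, as this minimal snippet demonstrates:

```python
# PEP 585 generic aliases are callable and construct the plain origin type;
# the type argument is not enforced at runtime.
xs = list[int]()
assert xs == [] and type(xs) is list
```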
@@ -1517,27 +1513,26 @@ def execute_model(
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
-
-            model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                scheduler_output=scheduler_output)
-            inputs_embeds = self.model.get_input_embeddings(
-                input_ids=input_ids,
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids[:num_scheduled_tokens],
                 multimodal_embeddings=mm_embeds or None,
             )
 
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)
+
             input_ids = None
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
-            model_kwargs = {}
+            model_mm_kwargs = {}
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
         else:
@@ -1571,7 +1566,7 @@ def execute_model(
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                     device=self.device,
                 ),
             )
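
The refactor above separates the freshly computed `inputs_embeds_scheduled` from the persistent `self.inputs_embeds` buffer that the forward pass actually sees: embeddings are computed only for the scheduled tokens, copied into a preallocated buffer, and the model is handed a padded slice of length `num_input_tokens` (which can exceed the scheduled count, e.g. when padding for CUDA graph capture). A toy sketch of that buffer pattern, with made-up sizes:

```python
# Toy illustration of the persistent-buffer pattern (assumed sizes, plain
# torch; not the runner's actual tensors).
import torch

hidden_size, max_num_tokens = 8, 16
inputs_embeds_buf = torch.zeros(max_num_tokens, hidden_size)

num_scheduled_tokens, num_input_tokens = 5, 8  # padded for graph capture
fresh = torch.randn(num_scheduled_tokens, hidden_size)  # stand-in embeddings

inputs_embeds_buf[:num_scheduled_tokens].copy_(fresh)  # in-place copy
padded_view = inputs_embeds_buf[:num_input_tokens]     # slice fed to model
assert padded_view.shape == (num_input_tokens, hidden_size)
```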
@@ -2149,6 +2144,30 @@ def rand_input_ids() -> torch.Tensor:
             yield
             input_ids.fill_(0)
 
+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
+            model_config=self.model_config,
+            seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
+        )
+        dummy_mm_data = dummy_decoder_data.multi_modal_data
+
+        # Result in the maximum GPU consumption of the model
+        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
+
+        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
+                                                         max_items_per_batch)
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            device=self.device,
+        )
+
     @torch.inference_mode()
     def _dummy_run(
         self,
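
`_get_mm_dummy_batch` builds one maximum-size item for the chosen modality and replicates it `max_items_per_batch` times, so profiling exercises the encoder at its memory bound. A toy version of the replicate-and-batch step using plain tensors (not the `MultiModalKwargs` API) looks like:

```python
# Toy replicate-and-batch illustration with assumed shapes; in the real code
# MultiModalKwargs.batch handles nested structures, and as_kwargs moves the
# result to the target device.
import torch

dummy_item = {"pixel_values": torch.zeros(3, 336, 336)}  # one max-size item
max_items_per_batch = 4
batched = {k: torch.stack([v] * max_items_per_batch)
           for k, v in dummy_item.items()}
assert batched["pixel_values"].shape == (4, 3, 336, 336)
```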
@@ -2213,16 +2232,14 @@ def _dummy_run(
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
-                model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                    num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
+                model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
             else:
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None
-                model_kwargs = {}
+                model_mm_kwargs = {}
 
             if self.uses_mrope:
                 positions = self.mrope_positions[:, :num_tokens]
@@ -2247,13 +2264,13 @@ def _dummy_run(
                     self.vllm_config,
                     num_tokens=num_tokens,
                     num_tokens_across_dp=num_tokens_across_dp):
-                outputs = model(
+                outputs = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                     **MultiModalKwargs.as_kwargs(
-                        model_kwargs,
+                        model_mm_kwargs,
                         device=self.device,
                     ),
                 )
@@ -2423,75 +2440,51 @@ def _dummy_pooler_run(
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-
-            # NOTE: Currently model is profiled with a single non-text
-            # modality with the max possible input tokens even when
-            # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry \
-                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
-            dummy_data_modality, max_tokens_per_mm_item = max(
-                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-            # Check how many items of this modality can be supported by
-            # the encoder budget.
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
-
-            max_num_mm_items_encoder_budget = encoder_budget // \
-                max_tokens_per_mm_item
-
-            # Check how many items of this modality can be supported by
-            # the decoder budget.
-            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-                self.model_config)[dummy_data_modality]
-
-            # NOTE: We do not consider max_num_batched_tokens on purpose
-            # because the multimodal embeddings can be generated in advance
-            # and chunked prefilled.
-            max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                max_mm_items_per_req
-
-            max_num_mm_items = max(
-                1,
-                min(max_num_mm_items_encoder_budget,
-                    max_num_mm_items_decoder_budget))
-
-            logger.info(
-                "Encoder cache will be initialized with a budget of %s tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
-
-            # Create dummy batch of multimodal inputs.
-            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
-                model_config=self.model_config,
-                seq_len=max_tokens_per_mm_item,
-                mm_counts={
-                    dummy_data_modality: 1
-                },
-            ).multi_modal_data
-
-            batched_dummy_mm_inputs = MultiModalKwargs.batch(
-                [dummy_mm_kwargs] * max_num_mm_items,
-                pin_memory=self.pin_memory)
-            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-                batched_dummy_mm_inputs,
-                device=self.device,
-            )
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                # NOTE: Currently model is profiled with a single non-text
+                # modality with the max possible input tokens even when
+                # it supports multiple.
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                logger.info(
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )
 
-            # Run multimodal encoder.
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                **batched_dummy_mm_inputs)
+                # Create dummy batch of multimodal inputs.
+                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )
 
-            sanity_check_mm_encoder_outputs(
-                dummy_encoder_outputs,
-                expected_num_items=max_num_mm_items,
-            )
+                # Run multimodal encoder.
+                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
+
+                sanity_check_mm_encoder_outputs(
+                    dummy_encoder_outputs,
+                    expected_num_items=max_mm_items_per_batch,
+                )
 
-            # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+                # Cache the dummy encoder outputs.
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))
 
         # Add `is_profile` here to pre-allocate communication buffers
         hidden_states, last_hidden_states \
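
For intuition, here is the item-count arithmetic that `mm_budget.get_max_items` now encapsulates, worked through with assumed numbers (the formulas come from the removed lines above; the concrete values are illustrative only):

```python
# Worked example of the profiling budget math, with assumed numbers.
encoder_budget = min(8192, 8192)  # min(compute budget, encoder cache size)
max_tokens_per_mm_item = 2048     # worst-case encoder tokens for one image

# Encoder side: how many max-size items fit in the encoder budget.
max_items_encoder = encoder_budget // max_tokens_per_mm_item  # -> 4
# Decoder side: per-prompt limit times the max number of requests.
max_items_decoder = 256 * 1       # max_num_reqs * items_per_prompt

max_num_mm_items = max(1, min(max_items_encoder, max_items_decoder))
assert max_num_mm_items == 4      # profile with 4 maximum-size items
```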