 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
                                                PallasAttentionBackend,
                                                PallasMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -72,8 +76,10 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
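+        # Pad both limits up to a multiple of NUM_QUERIES_PER_BLOCK so the
+        # token and request counts land on the Pallas attention kernel's
+        # query-block boundary.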
+        self.max_num_tokens = _get_padded_number(
+            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
+        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
+                                               NUM_QUERIES_PER_BLOCK)

         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -84,25 +90,38 @@ def __init__(
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()

+        # Multi-modal data support
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.uses_mrope = model_config.uses_mrope
+        # TODO: Support M-RoPE (e.g., Qwen2-VL)
+        assert not self.uses_mrope, "TPU does not support M-RoPE yet."
+
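+        # compute_encoder_budget returns the per-step encoder compute budget
+        # and the encoder cache capacity, both in input-sequence tokens.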
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+
+        # Lazy initialization
+        # self.model: nn.Module  # Set after load_model
+        self.kv_caches: list[torch.Tensor] = []
+        # req_id -> (input_id -> encoder_output)
+        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
+
+        # Request states.
+        self.requests: dict[str, CachedRequestState] = {}
         # Persistent batch.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
+            vocab_size=model_config.get_vocab_size(),
         )

-        # Request states.
-        self.requests: dict[str, CachedRequestState] = {}
-
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
-
-        # KV caches for forward pass
-        self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []
-
         # Cached torch/numpy tensor
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
@@ -164,6 +183,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.encoder_cache.pop(req_id, None)

         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -177,6 +197,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             if req_index is not None:
                 removed_req_indices.append(req_index)

+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)
+
         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
         # or running requests that are not scheduled in this step. We remove
@@ -426,6 +454,92 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices

+    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_inputs: list[MultiModalKwargs] = []
+        req_input_ids: list[tuple[str, int]] = []
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.requests[req_id]
+            for input_id in encoder_input_ids:
+                mm_inputs.append(req_state.mm_inputs[input_id])
+                req_input_ids.append((req_id, input_id))
+
+        # Batch mm inputs as much as we can: if a request in the batch has
+        # multiple modalities or a different modality than the previous one,
+        # we process it separately to preserve item order.
+        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+        # in the same batch while still being able to benefit from batching
+        # multimodal inputs. The proper solution should be reordering the
+        # encoder outputs.
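+        # e.g., inputs of modalities [image, image, audio, image] are grouped
+        # into [[image, image], [audio], [image]] before batching.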
+        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
+
+        encoder_outputs = []
+        for grouped_mm_inputs in grouped_mm_inputs_list:
+            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                           device=self.device)
+
+            # Run the encoder.
+            # `curr_group_outputs` is either of the following:
+            # 1. A tensor of shape (num_items, feature_size, hidden_size)
+            # in case feature_size is fixed across all multimodal items.
+            # 2. A list or tuple (length: num_items) of tensors, each of shape
+            # (feature_size, hidden_size) in case the feature size is dynamic
+            # depending on the input multimodal items.
+            curr_group_outputs = self.model.get_multimodal_embeddings(
+                **batched_mm_inputs)
+
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
+
+        # Cache the encoder outputs.
+        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
+            if req_id not in self.encoder_cache:
+                self.encoder_cache[req_id] = {}
+            self.encoder_cache[req_id][input_id] = output
+
+    def _gather_encoder_outputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> list[torch.Tensor]:
+        encoder_outputs: list[torch.Tensor] = []
+        for req_id in self.input_batch.req_ids:
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            mm_positions = req_state.mm_positions
+            for i, pos_info in enumerate(mm_positions):
+                start_pos = pos_info["offset"]
+                num_encoder_tokens = pos_info["length"]
+
+                # The encoder output is needed if the two ranges overlap:
+                # [num_computed_tokens,
+                #  num_computed_tokens + num_scheduled_tokens) and
+                # [start_pos, start_pos + num_encoder_tokens)
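+                # e.g., num_computed_tokens=10, num_scheduled_tokens=8,
+                # start_pos=12, num_encoder_tokens=20: the ranges [10, 18)
+                # and [12, 32) overlap, so the slice [0, 6) of the encoder
+                # output is gathered in this step.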
+                if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(num_computed_tokens - start_pos, 0)
+                end_idx = min(
+                    num_computed_tokens - start_pos + num_scheduled_tokens,
+                    num_encoder_tokens)
+                assert start_idx < end_idx
+                assert req_id in self.encoder_cache
+                assert i in self.encoder_cache[req_id]
+                encoder_output = self.encoder_cache[req_id][i]
+                encoder_outputs.append(encoder_output[start_idx:end_idx])
+        return encoder_outputs
+
     @torch.no_grad()
     def execute_model(
         self,
@@ -434,16 +548,42 @@ def execute_model(
         # Update cached state
         self._update_states(scheduler_output)

+        if self.is_multimodal_model:
+            # Run the multimodal encoder if any.
+            self._execute_encoder(scheduler_output)
+            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
+        else:
+            encoder_outputs = []
+
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

+        if self.is_multimodal_model:
+            # NOTE(woosuk): To unify token ids and soft tokens (vision
+            # embeddings), we always use embeddings (rather than token ids)
+            # as input to the multimodal model, even when the input is text.
+            if encoder_outputs:
+                inputs_embeds = self.model.get_input_embeddings(
+                    self.input_ids, encoder_outputs)
+            else:
+                inputs_embeds = self.model.get_input_embeddings(self.input_ids)
+            input_ids = None
+        else:
+            # For text-only models, we use token ids as input.
+            # While it is possible to use embeddings as input just like the
+            # multimodal models, it is not desirable for performance since
+            # then the embedding layer is not included in the compiled graph.
+            input_ids = self.input_ids
+            inputs_embeds = None
+
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
-                token_ids=self.input_ids,
-                position_ids=self.position_ids,
+                input_ids=input_ids,
+                positions=self.position_ids,
                 kv_caches=self.kv_caches,
+                inputs_embeds=inputs_embeds,
             )
             hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
@@ -538,14 +678,21 @@ def load_model(self) -> None:
                                    fullgraph=True,
                                    dynamic=False)

-    def dummy_run(
+    def _dummy_run(
         self,
         kv_caches,
         num_tokens: int,
     ) -> None:
-        input_ids = torch.zeros(num_tokens,
-                                dtype=torch.int32,
-                                device=self.device)
+        if self.is_multimodal_model:
+            input_ids = None
+            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
+                                        dtype=self.dtype,
+                                        device=self.device)
+        else:
+            input_ids = torch.zeros((num_tokens),
+                                    dtype=torch.int32,
+                                    device=self.device)
+            inputs_embeds = None
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
@@ -571,7 +718,10 @@ def dummy_run(
             num_seqs=num_tokens,
         )

-        torch._dynamo.mark_dynamic(input_ids, 0)
+        if self.is_multimodal_model:
+            torch._dynamo.mark_dynamic(inputs_embeds, 0)
+        else:
+            torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
         torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
@@ -580,7 +730,12 @@ def dummy_run(

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(input_ids, position_ids, kv_caches)
+            self.model(
+                input_ids=input_ids,
+                positions=position_ids,
+                kv_caches=kv_caches,
+                inputs_embeds=inputs_embeds,
+            )

     def capture_model(self) -> None:
         """Compile the model."""
@@ -590,11 +745,11 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self.dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(self.kv_caches, num_tokens)
             logger.info(" -- num_tokens: %d", num_tokens)
             xm.mark_step()
             xm.wait_device_ops()
-            if num_tokens >= self.scheduler_config.max_num_batched_tokens:
+            if num_tokens >= self.max_num_tokens:
                 break
             num_tokens *= 2
         end = time.perf_counter()
@@ -647,17 +802,20 @@ def __init__(self, model: nn.Module):

     def forward(
         self,
-        token_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.

         Args:
-            token_ids: The input token IDs of shape [num_tokens].
-            position_ids: The input position IDs of shape [num_tokens].
+            input_ids: The input token IDs of shape [num_tokens].
+            positions: The input position IDs of shape [num_tokens].
             kv_caches: The key and value caches. They can be None during the
                 memory profiling at initialization.
+            inputs_embeds: The input embeddings of shape [num_tokens,
+                hidden_size]. It is used for multimodal models.
         """
         # Skip this in memory profiling at initialization.
         if kv_caches[0][0].numel() > 0:
@@ -684,9 +842,9 @@ def forward(

         assert self.model is not None
         hidden_states = self.model(
-            token_ids,
-            position_ids,
-            kv_caches,
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
         )

         return hidden_states
@@ -699,6 +857,12 @@ def compute_logits(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
         return logits

+    def get_multimodal_embeddings(self, *args, **kwargs):
+        return self.model.get_multimodal_embeddings(*args, **kwargs)
+
+    def get_input_embeddings(self, *args, **kwargs):
+        return self.model.get_input_embeddings(*args, **kwargs)
+

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple