5555from  vllm .multimodal  import  MULTIMODAL_REGISTRY 
5656from  vllm .multimodal .inputs  import  (ImageItem , ModalityData ,
5757                                    MultiModalFieldConfig , MultiModalKwargs ,
58-                                     NestedTensors ,  VideoItem )
58+                                     VideoItem )
5959from  vllm .multimodal .parse  import  (ImageSize , ModalityDataItems ,
6060                                   MultiModalDataItems , MultiModalDataParser )
6161from  vllm .multimodal .processing  import  (BaseMultiModalProcessor ,
@@ -1233,7 +1233,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
12331233        return  modalities 
12341234
12351235    def  get_multimodal_embeddings (
1236-             self , ** kwargs ) ->  Optional [List [ Tuple [ NestedTensors ,  str ] ]]:
1236+             self , ** kwargs ) ->  Optional [tuple [ torch . Tensor , ... ]]:
12371237
12381238        modalities  =  self ._parse_and_validate_multimodal_inputs (** kwargs )
12391239        if  not  modalities :
@@ -1260,8 +1260,7 @@ def get_multimodal_embeddings(
12601260    def  get_input_embeddings (
12611261        self ,
12621262        input_ids : torch .Tensor ,
1263-         multimodal_embeddings : Optional [List [Tuple [NestedTensors ,
1264-                                                    str ]]] =  None ,
1263+         multimodal_embeddings : Optional [tuple [torch .Tensor , ...]] =  None ,
12651264    ) ->  torch .Tensor :
12661265        inputs_embeds  =  self .language_model .get_input_embeddings (input_ids )
12671266        if  multimodal_embeddings  is  not   None :
@@ -1270,6 +1269,33 @@ def get_input_embeddings(
12701269                [self .config .image_token_id , self .config .video_token_id ])
12711270        return  inputs_embeds 
12721271
1272+     def  get_input_embeddings_v0 (
1273+         self ,
1274+         input_ids : torch .Tensor ,
1275+         image_input : Optional [tuple [torch .Tensor , ...]] =  None ,
1276+         video_input : Optional [tuple [torch .Tensor , ...]] =  None ,
1277+     ) ->  torch .Tensor :
1278+ 
1279+         inputs_embeds  =  self .get_input_embeddings (input_ids )
1280+         if  image_input  is  not   None :
1281+             image_embeds  =  self ._process_image_input (image_input )
1282+             inputs_embeds  =  merge_multimodal_embeddings (
1283+                 input_ids ,
1284+                 inputs_embeds ,
1285+                 image_embeds ,
1286+                 placeholder_token_id = self .config .image_token_id ,
1287+             )
1288+ 
1289+         if  video_input  is  not   None :
1290+             video_embeds  =  self ._process_video_input (video_input )
1291+             inputs_embeds  =  merge_multimodal_embeddings (
1292+                 input_ids ,
1293+                 inputs_embeds ,
1294+                 video_embeds ,
1295+                 placeholder_token_id = self .config .video_token_id ,
1296+             )
1297+         return  inputs_embeds 
1298+ 
12731299    def  forward (
12741300        self ,
12751301        input_ids : torch .Tensor ,
@@ -1303,22 +1329,25 @@ def forward(
13031329        if  intermediate_tensors  is  not   None :
13041330            inputs_embeds  =  None 
13051331
1306-         # NOTE: In v1, inputs_embeds is always generated at model runner, this 
1307-         # condition is for v0 compatibility. 
1332+         # NOTE: In v1, inputs_embeds is always generated at model runner from 
1333+         # `get_multimodal_embeddings` and `get_input_embeddings`, this 
1334+         # condition is only for v0 compatibility. 
13081335        elif  inputs_embeds  is  None :
1309-             multimodal_embeddings  =  self .get_multimodal_embeddings (** kwargs )
1310- 
1311-             # We need to check for usage of mrope here in case there is 
1312-             # multimodal data. 
1313-             # TODO (ywang96): move this to model runner in V1. 
1314-             if  multimodal_embeddings  is  not   None  and  uses_mrope (self .config ):
1315-                 assert  positions .ndim  ==  2  and  positions .size (0 ) ==  3 , (
1316-                     "multimodal section rotary embedding requires " 
1317-                     f"(3, seq_len) positions, but got { positions .size ()}  " )
1318- 
1319-             inputs_embeds  =  self .get_input_embeddings (input_ids ,
1320-                                                       multimodal_embeddings )
1321-             input_ids  =  None 
1336+             image_input  =  self ._parse_and_validate_image_input (** kwargs )
1337+             video_input  =  self ._parse_and_validate_video_input (** kwargs )
1338+ 
1339+             if  image_input  is  None  and  video_input  is  None :
1340+                 inputs_embeds  =  None 
1341+             else :
1342+                 if  uses_mrope (self .config ):
1343+                     assert  positions .ndim  ==  2  and  positions .size (0 ) ==  3 , (
1344+                         "multimodal section rotary embedding requires " 
1345+                         f"(3, seq_len) positions, but got { positions .size ()}  " )
1346+                 inputs_embeds  =  self .get_input_embeddings_v0 (
1347+                     input_ids ,
1348+                     image_input = image_input ,
1349+                     video_input = video_input )
1350+                 input_ids  =  None 
13221351
13231352        hidden_states  =  self .language_model .model (
13241353            input_ids = input_ids ,
0 commit comments