 import copy
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict

 import numpy as np
 import torch
@@ -118,6 +118,128 @@ def get_num_tokens_per_image(
         )
         return unpadded_feature_size + newline_feature_size + base_feature_size

+    def _postprocess(self, input_ids, mm_features):
+        # Define model specific variables here before shared logic
+        mm_tokens = torch.tensor([self.model_config.image_token_index
+                                  ]).to(input_ids.device)
+        model_hidden_size = self.model_config.text_config.hidden_size
+        vocab_size = self.model_config.text_config.vocab_size
+        start_len = end_len = 0  # for llava, need not append start/end token around each image token
+        # End model specific variables
+
+        ## find mm token positions in input_ids
+        mm_token_positions = torch.where(torch.isin(input_ids, mm_tokens))[0]
+        num_medias = num_mm_tokens = len(mm_token_positions)
+        if num_medias > 1 and isinstance(mm_features, torch.Tensor):
+            mm_features = list(
+                mm_features.split(mm_features.shape[0] // num_medias))
+
+        if isinstance(mm_features, torch.Tensor):
+            # 1 prompt + 1 media
+            # "split" means what a single mm_token in the input_ids should represent
+            # image: one split --> one frame
+            # video: one split --> N frames
+            num_frames, mm_feature_length, mm_hidden_dim = mm_features.shape
+            mm_lengths_per_split = [mm_feature_length * num_frames]
+            mm_lengths_per_frame = [mm_feature_length]
+        elif isinstance(mm_features, list):
+            # 1 prompt + N media
+            num_frames = len(mm_features) if mm_features[0].dim() == 2 else sum(
+                [f.shape[0] for f in mm_features])
+            mm_lengths_per_split = [
+                f.shape[0] if f.dim() == 2 else f.shape[0] * f.shape[1]
+                for f in mm_features
+            ]
+            mm_lengths_per_frame = [
+                f.shape[0] if f.dim() == 2 else f.shape[1] for f in mm_features
+            ]
+            mm_hidden_dim = mm_features[0].shape[-1]
+            mm_features = torch.cat(mm_features, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid multimodal features type: {type(mm_features)}")
+        mm_total_length = sum(mm_lengths_per_split)
+        assert mm_hidden_dim == model_hidden_size, "Multimodal embedding_dim must match model hidden_size"
+
+        ## split input_ids into segments by isolating mm tokens
+        mm_split_positions = torch.cat(
+            [mm_token_positions, mm_token_positions + 1]).unique()
+        input_ids_splits = list(input_ids.tensor_split(mm_split_positions.cpu(
+        )))  # len(input_ids_splits) = num_segments after mm tokens are isolated
+        mm_ids_splits = list(
+            torch.arange(vocab_size,
+                         vocab_size + mm_total_length,
+                         device=input_ids.device).split(mm_lengths_per_split)
+        )  # len(mm_ids_splits) = num_mm_segments
+
+        for i, mm_ids in enumerate(mm_ids_splits):
+            mm_ids = mm_ids.reshape(-1, mm_lengths_per_frame[i])
+            mm_ids_splits[i] = mm_ids.flatten()
+
+        ## replace mm token ids with the expanded out-of-vocab ids
+        mm_split_idx = 0
+        for i, split in enumerate(input_ids_splits):
+            if torch.isin(split, mm_tokens).any().item():
+                input_ids_splits[i] = mm_ids_splits[mm_split_idx]
+                mm_split_idx += 1
+        assert mm_split_idx == len(
+            mm_ids_splits), "All mm_ids_splits should be consumed"
+
+        ## concat text & mm input_ids, wrap mm feature in prompt tuning config
+        fused_input_ids = torch.cat(input_ids_splits).to(
+            device=input_ids.device)
+        fused_length = len(input_ids) + mm_total_length + num_frames * (
+            start_len + end_len) - num_medias
+        assert len(
+            fused_input_ids
+        ) == fused_length, f"Fused input_ids length {len(fused_input_ids)} should match the sum of text and multimodal embedding lengths {fused_length}"
+
+        # [num_frames, feature_length, hidden_dim] -> [num_frames * feature_length, hidden_dim]
+        mm_features = mm_features.view(-1, mm_features.shape[-1])
+        return fused_input_ids, mm_features
+
+
+    def attach_multimodal_embeddings(
+            self, inputs: TextPrompt,
+            multimodal_embedding: Dict[str, List[torch.Tensor]],
+            sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        """
+        Attach pre-processed multimodal embeddings to the text token stream for the LlavaNext model.
+        This method skips vision processing and works with externally provided embeddings.
+        It replaces/expands image placeholders in the text with the appropriate tokens and prepares
+        the embeddings for the model forward pass.
+        Args:
+            inputs: Text prompt containing image placeholders
+            multimodal_embedding: Dictionary containing pre-processed image embedding data
+        Returns:
+            Tuple of (token_ids, extra_processed_inputs) where:
+                - token_ids: List of processed token IDs with expanded image placeholders
+                - extra_processed_inputs: Optional dictionary containing multimodal embeddings
+        """
+        text_prompt = inputs.get("prompt")
+        if not text_prompt:
+            raise ValueError("Text prompt is required but not provided")
+
+
+
+        if not isinstance(multimodal_embedding, dict):
+            raise ValueError("multimodal_embedding must be a dictionary")
+
+        if 'image' not in multimodal_embedding:
+            raise ValueError(
+                "Only the image modality is supported for external multimodal embeddings"
+            )
+
+        input_ids = self.tokenizer(
+            text_prompt, return_tensors="pt").input_ids[0]
+        mm_features = torch.stack(multimodal_embedding['image'])
+        fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
+        multimodal_data = {}
+        multimodal_data["multimodal_embedding"] = mm_features
+        return fused_input_ids.to(torch.int32).tolist(), {
+            "multimodal_data": multimodal_data
+        }

     @torch.inference_mode()
     def __call__(
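Reviewer note: the core trick in `_postprocess` is to replace each image placeholder token with a contiguous block of out-of-vocabulary token ids (starting at `vocab_size`), one id per multimodal feature row. Below is a minimal, self-contained sketch of that expansion with toy numbers; the names and values (`vocab_size=100`, `image_token_index=99`, `feature_len=4`) are illustrative only and are not part of this patch.

```python
import torch

# Toy stand-ins for the real config values (illustrative only).
vocab_size = 100         # stand-in for text_config.vocab_size
image_token_index = 99   # stand-in for model_config.image_token_index
feature_len = 4          # feature rows contributed by one image

# Prompt tokens with a single image placeholder in the middle.
input_ids = torch.tensor([5, 6, image_token_index, 7])

# Isolate the placeholder, then swap it for ids [vocab_size, vocab_size + feature_len).
pos = torch.where(input_ids == image_token_index)[0]
splits = list(input_ids.tensor_split(torch.cat([pos, pos + 1]).unique()))
mm_ids = torch.arange(vocab_size, vocab_size + feature_len)
splits = [mm_ids if (s == image_token_index).any() else s for s in splits]
fused_input_ids = torch.cat(splits)

print(fused_input_ids)  # tensor([  5,   6, 100, 101, 102, 103,   7])
```

The fused ids line up one-to-one with the flattened `[num_frames * feature_length, hidden_dim]` tensor returned alongside them, which is presumably how the prompt-tuning path later substitutes the real image features at those out-of-vocab positions.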
@@ -158,9 +280,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
                  **kwargs) -> None:
         super().__init__()
         self.model_config = model_config
-        pretrained_config = model_config.pretrained_config
+        self.pretrained_config = model_config.pretrained_config
         self.device = f"cuda:{model_config.mapping.rank}"
-        model_path = pretrained_config._name_or_path
+        model_path = self.pretrained_config._name_or_path

         # Determine the actual local path for model files
         if os.path.isdir(model_path):
@@ -200,7 +322,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
             self.vision_tower = hf_vision_tower.to(self.device)
         else:
             vision_model_config = ModelConfig(
-                pretrained_config=model_config.pretrained_config.vision_config,
+                pretrained_config=self.pretrained_config.vision_config,
                 attn_backend="TRTLLM")
             self.vision_tower = CLIPVisionModel(vision_model_config).to(
                 self.device).to(self.dtype)
@@ -210,13 +332,13 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         self.mm_projector = hf_mm_projector
         self.image_newline = hf_image_newline
         self.vision_feature_select_strategy = getattr(
-            model_config.pretrained_config, "vision_feature_select_strategy",
+            self.pretrained_config, "vision_feature_select_strategy",
             "default")

         self.post_config()

     def post_config(self):
-        self.config = self.model_config.pretrained_config.vision_config
+        self.config = self.pretrained_config.vision_config

     # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L284
     def pack_image_features(self,
@@ -234,7 +356,7 @@ def pack_image_features(self,

                 num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                     image_sizes[image_idx],
-                    self.model_config.pretrained_config.image_grid_pinpoints,
+                    self.pretrained_config.image_grid_pinpoints,
                     self.config.image_size,
                 )

@@ -296,7 +418,7 @@ def forward(self, multimodal_params: List[MultimodalParams]):
         image_num_patches = [
             image_size_to_num_patches(
                 image_size=imsize,
-                grid_pinpoints=self.model_config.pretrained_config.image_grid_pinpoints,
+                grid_pinpoints=self.pretrained_config.image_grid_pinpoints,
                 patch_size=self.config.image_size,
             ) for imsize in image_sizes
         ]
@@ -396,7 +518,13 @@ def forward(
         mm_embeds = []
         if len(multimodal_params) > 0:
             if not DISAGG:
-                mm_embeds = self.mm_encoder.forward(multimodal_params)
+                if multimodal_params[0].multimodal_data.get("multimodal_embedding", None) is not None:
+                    mm_embeds = [
+                        multimodal_param.multimodal_data["multimodal_embedding"]
+                        for multimodal_param in multimodal_params
+                    ]
+                else:
+                    mm_embeds = self.mm_encoder.forward(multimodal_params)
             else:
                 mm_embeds = [
                     multimodal_param.multimodal_data["multimodal_embedding"]
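The change to `forward` above adds a simple dispatch: when a request already carries `multimodal_data["multimodal_embedding"]` (for example, produced by `attach_multimodal_embeddings`), the vision encoder is skipped and the stored features are used directly; otherwise the encoder runs as before. A simplified, runnable sketch of that dispatch follows; it uses plain dicts and a dummy encoder instead of the real `MultimodalParams` objects and `self.mm_encoder`, so treat it as an illustration rather than the actual implementation.

```python
from typing import Any, Callable, Dict, List

import torch


def gather_mm_embeds(
        multimodal_params: List[Dict[str, Any]],
        encoder_forward: Callable[[List[Dict[str, Any]]], List[torch.Tensor]]
) -> List[torch.Tensor]:
    """Mirrors the new branch in forward(): prefer attached embeddings."""
    if not multimodal_params:
        return []
    if multimodal_params[0].get("multimodal_embedding") is not None:
        # Embeddings were attached at preprocessing time -> skip the vision tower.
        return [p["multimodal_embedding"] for p in multimodal_params]
    # Otherwise fall back to running the vision encoder.
    return encoder_forward(multimodal_params)


def dummy_encoder(params):
    # Stand-in for the real vision encoder forward pass.
    return [torch.zeros(4, 8) for _ in params]


params = [{"multimodal_embedding": torch.randn(4, 8)}]
print(gather_mm_embeds(params, dummy_encoder)[0].shape)  # torch.Size([4, 8])
```

One thing worth double-checking in review: the new branch keys off `multimodal_params[0]` only, so it assumes a batch is homogeneous, i.e. either every request has attached embeddings or none does.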