11import copy
2+ import os
23from typing import Any , Dict , List , Optional , Tuple , Union
34
45import torch
2021from .modeling_multimodal_utils import fuse_input_embeds
2122from .modeling_utils import register_auto_model
2223
24+ DISAGG = os .getenv ('TLLM_MULTIMODAL_DISAGGREGATED' , '0' ) == '1'
25+
2326
2427class Qwen2VLInputProcessorBase (InputProcessor ):
2528
@@ -322,7 +325,8 @@ def get_mrope_config(
322325 concat_cos_sin = concat_cos_sin .reshape (concat_cos_sin .shape [0 ], - 1 )
323326 mrope_config = {}
324327 mrope_config ['mrope_rotary_cos_sin' ] = concat_cos_sin .to ('cpu' )
325- mrope_config ['mrope_position_deltas' ] = mrope_position_deltas .to ('cpu' )
328+ mrope_config ['mrope_position_deltas' ] = mrope_position_deltas .to (
329+ 'cpu' ).to (torch .int32 )
326330 return mrope_config
327331
328332 @torch .inference_mode ()
@@ -364,11 +368,11 @@ def __call__(
364368 processed_inputs .get ('video_grid_thw' , None ),
365369 processed_inputs .get ('attention_mask' , None ),
366370 processed_inputs .get ('second_per_grid_ts' , None ))
371+ multimodal_data ["mrope_config" ] = mrope_config
367372
368373 fused_input_ids = self ._postprocess (input_ids [0 ])
369374
370375 return fused_input_ids .to (torch .int32 ).tolist (), {
371- "mrope_config" : mrope_config ,
372376 "multimodal_data" : multimodal_data ,
373377 }
374378
@@ -411,16 +415,14 @@ def _parse_and_batch_multimodal_data(
411415
412416 for multimodal_param in multimodal_params :
413417 # Process images if present
414- if "image" in multimodal_param .multimodal_data and multimodal_param .multimodal_data [
415- "image" ]:
418+ if multimodal_param .multimodal_data .get ("image" ) is not None :
416419 pixel_values_list .append (
417420 multimodal_param .multimodal_data ["image" ]["pixel_values" ])
418421 image_grid_thw_list .append (
419422 multimodal_param .multimodal_data ["image" ]["image_grid_thw" ])
420423
421424 # Process videos if present
422- if "video" in multimodal_param .multimodal_data and multimodal_param .multimodal_data [
423- "video" ]:
425+ if multimodal_param .multimodal_data .get ("video" ) is not None :
424426 pixel_values_videos_list .append (
425427 multimodal_param .multimodal_data ["video" ]
426428 ["pixel_values_videos" ])
@@ -457,6 +459,8 @@ def forward(self, multimodal_params: List[MultimodalParams]):
457459
458460 mm_content_data , mm_extra_data = self ._parse_and_batch_multimodal_data (
459461 multimodal_params )
462+ print (f"mm_content_data: { mm_content_data } " )
463+ print (f"mm_extra_data: { mm_extra_data } " )
460464 pixel_values = mm_content_data .get ("pixel_values" , None )
461465 pixel_values_videos = mm_content_data .get ("pixel_values_videos" , None )
462466
@@ -478,7 +482,6 @@ def forward(self, multimodal_params: List[MultimodalParams]):
478482 pixel_values_videos = pixel_values_videos .to (self .visual .dtype )
479483 embeds .append (
480484 self .visual (pixel_values_videos , grid_thw = video_grid_thw ))
481-
482485 return embeds
483486
484487
@@ -526,16 +529,19 @@ def _parse_mrope_config(
526529 mrope_config = {}
527530 mrope_rotary_cos_sin_list = []
528531 mrope_position_deltas_list = []
529-
530532 for multimodal_param in multimodal_params :
531- if hasattr (multimodal_param ,
532- 'mrope_config' ) and multimodal_param .mrope_config :
533- if 'mrope_rotary_cos_sin' in multimodal_param .mrope_config :
533+ if multimodal_param .multimodal_data and multimodal_param .multimodal_data .get (
534+ 'mrope_config' ):
535+ if multimodal_param .multimodal_data ['mrope_config' ].get (
536+ 'mrope_rotary_cos_sin' ) is not None :
534537 mrope_rotary_cos_sin_list .append (
535- multimodal_param .mrope_config ['mrope_rotary_cos_sin' ])
536- if 'mrope_position_deltas' in multimodal_param .mrope_config :
538+ multimodal_param .multimodal_data ['mrope_config' ]
539+ ['mrope_rotary_cos_sin' ])
540+ if multimodal_param .multimodal_data ['mrope_config' ].get (
541+ 'mrope_position_deltas' ) is not None :
537542 mrope_position_deltas_list .append (
538- multimodal_param .mrope_config ['mrope_position_deltas' ])
543+ multimodal_param .multimodal_data ['mrope_config' ]
544+ ['mrope_position_deltas' ])
539545
540546 if mrope_rotary_cos_sin_list :
541547 mrope_config ['mrope_rotary_cos_sin' ] = torch .cat (
@@ -544,6 +550,8 @@ def _parse_mrope_config(
544550 if mrope_position_deltas_list :
545551 mrope_config ['mrope_position_deltas' ] = torch .cat (
546552 mrope_position_deltas_list , dim = 0 )
553+ print (f"mrope_config: { mrope_config } " )
554+ return mrope_config
547555
548556 @torch .inference_mode ()
549557 def forward (
@@ -568,8 +576,14 @@ def forward(
568576 mrope_config = {}
569577
570578 if len (multimodal_params ) > 0 :
571- mm_embeds = self .mm_encoder .forward (
572- multimodal_params [:num_context_requests ])
579+ if not DISAGG :
580+ mm_embeds = self .mm_encoder .forward (
581+ multimodal_params [:num_context_requests ])
582+ else :
583+ mm_embeds = [
584+ multimodal_param .multimodal_data ["multimodal_embedding" ]
585+ for multimodal_param in multimodal_params
586+ ]
573587 mrope_config = self ._parse_mrope_config (multimodal_params )
574588
575589 input_ids , input_embeds = fuse_input_embeds (self .llm .model .embed_tokens ,
@@ -592,8 +606,9 @@ class Qwen2VLModel(Qwen2VLModelBase):
592606
593607 def __init__ (self , model_config : ModelConfig [PretrainedConfig ], * args ,
594608 ** kwargs ):
595- self .mm_encoder = Qwen2VisionModelBase (model_config ,
596- Qwen2VLForConditionalGeneration )
609+ if not DISAGG :
610+ self .mm_encoder = Qwen2VisionModelBase (
611+ model_config , Qwen2VLForConditionalGeneration )
597612 super ().__init__ (model_config , * args , ** kwargs )
598613
599614
@@ -603,6 +618,7 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
603618
604619 def __init__ (self , model_config : ModelConfig [PretrainedConfig ], * args ,
605620 ** kwargs ):
606- self .mm_encoder = Qwen2VisionModelBase (
607- model_config , Qwen2_5_VLForConditionalGeneration )
621+ if not DISAGG :
622+ self .mm_encoder = Qwen2VisionModelBase (
623+ model_config , Qwen2_5_VLForConditionalGeneration )
608624 super ().__init__ (model_config , * args , ** kwargs )
0 commit comments