Skip to content

Commit a508aee

Browse files
move mrope_config and mm_embedding under MultimodalParams
Signed-off-by: yechank <[email protected]>
1 parent 0ab4503 commit a508aee

File tree

10 files changed

+169
-114
lines changed

10 files changed

+169
-114
lines changed

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,10 @@ def __call__(
100100
"pixel_values"]
101101
input_ids = preprocess_outputs[0]["mm_processor_kwargs"]["input_ids"]
102102
mm_features = self._process(pixel_values)
103+
multimodal_data = {}
104+
multimodal_data["multimodal_embedding"] = mm_features
103105
return input_ids[0].to(torch.int32).tolist(), {
104-
"mm_embedding": mm_features
106+
"multimodal_data": multimodal_data
105107
}
106108

107109

@@ -163,7 +165,7 @@ def forward(
163165

164166
multimodal_params = kwargs.get("multimodal_params", [])
165167
mm_embed = [
166-
multimodal_param.multimodal_embedding
168+
multimodal_param.multimodal_data["multimodal_embedding"]
167169
for multimodal_param in multimodal_params
168170
]
169171
assert mm_embed == [] or len(

tensorrt_llm/_torch/models/modeling_hyperclovax.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import math
3+
import os
34
from functools import partial
45
from itertools import chain
56
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -25,6 +26,8 @@
2526
from .modeling_siglip import SiglipVisionModel
2627
from .modeling_utils import register_auto_model
2728

29+
DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
30+
2831

2932
# Copied from HyperCLOVAX-SEED-Vision-Instruct-3B/modeling_hyperclovax.py
3033
def select_best_resolution(original_size: tuple,
@@ -969,8 +972,8 @@ def __init__(self, model_config: ModelConfig):
969972
self.model_config = model_config
970973
if hasattr(self, "llm"):
971974
return
972-
973-
self.mm_encoder = HCXVisionModel(model_config)
975+
if not DISAGG:
976+
self.mm_encoder = HCXVisionModel(model_config)
974977
llm_model_config = copy.deepcopy(model_config)
975978
llm_model_config.pretrained_config = PretrainedConfig.from_dict(
976979
llm_model_config.pretrained_config.language_config)
@@ -1026,7 +1029,13 @@ def forward(
10261029
assert len(multimodal_params) == num_context_requests == len(
10271030
multimodal_params
10281031
), f"Number of multimodal tensors ({len(multimodal_params)}) should be equal to number of context requests ({num_context_requests}) in the batch."
1029-
mm_embeds = self.mm_encoder.forward(multimodal_params)
1032+
if not DISAGG:
1033+
mm_embeds = self.mm_encoder.forward(multimodal_params)
1034+
else:
1035+
mm_embeds = [
1036+
multimodal_param.multimodal_data["multimodal_embedding"]
1037+
for multimodal_param in multimodal_params
1038+
]
10301039

10311040
input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
10321041
input_ids, mm_embeds)

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,10 @@ def __call__(
851851
mm_embeds = self.encoder.multi_modal_projector(mm_embeds)
852852
# for fuse_input_embeds
853853
token_ids[token_ids == self.image_token_index] = self.vocab_size + 1
854-
return token_ids.tolist(), {"mm_embedding": mm_embeds}
854+
855+
multimodal_data = {}
856+
multimodal_data["multimodal_embedding"] = mm_embeds
857+
return token_ids.tolist(), {"multimodal_data": multimodal_data}
855858
else:
856859
return processed["input_ids"].squeeze().tolist(), {}
857860

@@ -882,8 +885,12 @@ def forward(
882885
spec_metadata: Optional[SpecMetadata] = None,
883886
**kwargs,
884887
) -> torch.Tensor:
885-
mm_embed = kwargs.get("multi_modal_data", [])
886-
if mm_embed:
888+
multimodal_params = kwargs.get("multimodal_params", [])
889+
if multimodal_params:
890+
mm_embed = [
891+
multimodal_param.multimodal_data["multimodal_embedding"]
892+
for multimodal_param in multimodal_params
893+
]
887894
_, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
888895
input_ids, mm_embed)
889896
return super().forward(attn_metadata,

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,10 @@ def __call__(
210210
mm_features = torch.stack(
211211
[self._process(tensor) for tensor in mm_tensor])
212212
fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
213+
multimodal_data = {}
214+
multimodal_data["multimodal_embedding"] = mm_features
213215
return fused_input_ids.to(torch.int32).tolist(), {
214-
"mm_embedding": mm_features
216+
"multimodal_data": multimodal_data
215217
}
216218

217219

@@ -273,7 +275,7 @@ def forward(
273275

274276
multimodal_params = kwargs.get("multimodal_params", [])
275277
mm_embed = [
276-
multimodal_param.multimodal_embedding
278+
multimodal_param.multimodal_data["multimodal_embedding"]
277279
for multimodal_param in multimodal_params
278280
]
279281
assert mm_embed == [] or len(

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import os
23
from typing import Any, Dict, List, Optional, Tuple, Union
34

45
import torch
@@ -20,6 +21,8 @@
2021
from .modeling_multimodal_utils import fuse_input_embeds
2122
from .modeling_utils import register_auto_model
2223

24+
DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
25+
2326

2427
class Qwen2VLInputProcessorBase(InputProcessor):
2528

@@ -322,7 +325,8 @@ def get_mrope_config(
322325
concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], -1)
323326
mrope_config = {}
324327
mrope_config['mrope_rotary_cos_sin'] = concat_cos_sin.to('cpu')
325-
mrope_config['mrope_position_deltas'] = mrope_position_deltas.to('cpu')
328+
mrope_config['mrope_position_deltas'] = mrope_position_deltas.to(
329+
'cpu').to(torch.int32)
326330
return mrope_config
327331

328332
@torch.inference_mode()
@@ -364,11 +368,11 @@ def __call__(
364368
processed_inputs.get('video_grid_thw', None),
365369
processed_inputs.get('attention_mask', None),
366370
processed_inputs.get('second_per_grid_ts', None))
371+
multimodal_data["mrope_config"] = mrope_config
367372

368373
fused_input_ids = self._postprocess(input_ids[0])
369374

370375
return fused_input_ids.to(torch.int32).tolist(), {
371-
"mrope_config": mrope_config,
372376
"multimodal_data": multimodal_data,
373377
}
374378

@@ -411,16 +415,14 @@ def _parse_and_batch_multimodal_data(
411415

412416
for multimodal_param in multimodal_params:
413417
# Process images if present
414-
if "image" in multimodal_param.multimodal_data and multimodal_param.multimodal_data[
415-
"image"]:
418+
if multimodal_param.multimodal_data.get("image") is not None:
416419
pixel_values_list.append(
417420
multimodal_param.multimodal_data["image"]["pixel_values"])
418421
image_grid_thw_list.append(
419422
multimodal_param.multimodal_data["image"]["image_grid_thw"])
420423

421424
# Process videos if present
422-
if "video" in multimodal_param.multimodal_data and multimodal_param.multimodal_data[
423-
"video"]:
425+
if multimodal_param.multimodal_data.get("video") is not None:
424426
pixel_values_videos_list.append(
425427
multimodal_param.multimodal_data["video"]
426428
["pixel_values_videos"])
@@ -457,6 +459,8 @@ def forward(self, multimodal_params: List[MultimodalParams]):
457459

458460
mm_content_data, mm_extra_data = self._parse_and_batch_multimodal_data(
459461
multimodal_params)
462+
print(f"mm_content_data: {mm_content_data}")
463+
print(f"mm_extra_data: {mm_extra_data}")
460464
pixel_values = mm_content_data.get("pixel_values", None)
461465
pixel_values_videos = mm_content_data.get("pixel_values_videos", None)
462466

@@ -478,7 +482,6 @@ def forward(self, multimodal_params: List[MultimodalParams]):
478482
pixel_values_videos = pixel_values_videos.to(self.visual.dtype)
479483
embeds.append(
480484
self.visual(pixel_values_videos, grid_thw=video_grid_thw))
481-
482485
return embeds
483486

484487

@@ -526,16 +529,19 @@ def _parse_mrope_config(
526529
mrope_config = {}
527530
mrope_rotary_cos_sin_list = []
528531
mrope_position_deltas_list = []
529-
530532
for multimodal_param in multimodal_params:
531-
if hasattr(multimodal_param,
532-
'mrope_config') and multimodal_param.mrope_config:
533-
if 'mrope_rotary_cos_sin' in multimodal_param.mrope_config:
533+
if multimodal_param.multimodal_data and multimodal_param.multimodal_data.get(
534+
'mrope_config'):
535+
if multimodal_param.multimodal_data['mrope_config'].get(
536+
'mrope_rotary_cos_sin') is not None:
534537
mrope_rotary_cos_sin_list.append(
535-
multimodal_param.mrope_config['mrope_rotary_cos_sin'])
536-
if 'mrope_position_deltas' in multimodal_param.mrope_config:
538+
multimodal_param.multimodal_data['mrope_config']
539+
['mrope_rotary_cos_sin'])
540+
if multimodal_param.multimodal_data['mrope_config'].get(
541+
'mrope_position_deltas') is not None:
537542
mrope_position_deltas_list.append(
538-
multimodal_param.mrope_config['mrope_position_deltas'])
543+
multimodal_param.multimodal_data['mrope_config']
544+
['mrope_position_deltas'])
539545

540546
if mrope_rotary_cos_sin_list:
541547
mrope_config['mrope_rotary_cos_sin'] = torch.cat(
@@ -544,6 +550,8 @@ def _parse_mrope_config(
544550
if mrope_position_deltas_list:
545551
mrope_config['mrope_position_deltas'] = torch.cat(
546552
mrope_position_deltas_list, dim=0)
553+
print(f"mrope_config: {mrope_config}")
554+
return mrope_config
547555

548556
@torch.inference_mode()
549557
def forward(
@@ -568,8 +576,14 @@ def forward(
568576
mrope_config = {}
569577

570578
if len(multimodal_params) > 0:
571-
mm_embeds = self.mm_encoder.forward(
572-
multimodal_params[:num_context_requests])
579+
if not DISAGG:
580+
mm_embeds = self.mm_encoder.forward(
581+
multimodal_params[:num_context_requests])
582+
else:
583+
mm_embeds = [
584+
multimodal_param.multimodal_data["multimodal_embedding"]
585+
for multimodal_param in multimodal_params
586+
]
573587
mrope_config = self._parse_mrope_config(multimodal_params)
574588

575589
input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
@@ -592,8 +606,9 @@ class Qwen2VLModel(Qwen2VLModelBase):
592606

593607
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
594608
**kwargs):
595-
self.mm_encoder = Qwen2VisionModelBase(model_config,
596-
Qwen2VLForConditionalGeneration)
609+
if not DISAGG:
610+
self.mm_encoder = Qwen2VisionModelBase(
611+
model_config, Qwen2VLForConditionalGeneration)
597612
super().__init__(model_config, *args, **kwargs)
598613

599614

@@ -603,6 +618,7 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
603618

604619
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
605620
**kwargs):
606-
self.mm_encoder = Qwen2VisionModelBase(
607-
model_config, Qwen2_5_VLForConditionalGeneration)
621+
if not DISAGG:
622+
self.mm_encoder = Qwen2VisionModelBase(
623+
model_config, Qwen2_5_VLForConditionalGeneration)
608624
super().__init__(model_config, *args, **kwargs)

tensorrt_llm/_torch/models/modeling_vila.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,8 +1107,10 @@ def __call__(
11071107
) # use_fast uses Pytorch GPU preprocessing, otherwise uses PIL CPU preprocessing
11081108
mm_features = self._process(mm_tensor, block_sizes)
11091109
fused_input_ids, mm_features = self._postprocess(input_ids, mm_features)
1110+
multimodal_data = {}
1111+
multimodal_data["multimodal_embedding"] = mm_features
11101112
return fused_input_ids.to(torch.int32).tolist(), {
1111-
"mm_embedding": mm_features
1113+
"multimodal_data": multimodal_data
11121114
}
11131115

11141116

@@ -1163,7 +1165,7 @@ def forward(
11631165
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
11641166
multimodal_params = kwargs.get("multimodal_params", [])
11651167
mm_embed = [
1166-
multimodal_param.multimodal_embedding
1168+
multimodal_param.multimodal_data["multimodal_embedding"]
11671169
for multimodal_param in multimodal_params
11681170
]
11691171

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,30 +1188,15 @@ def _prepare_tp_inputs(
11881188
prompt_lengths.append(len(prompt_tokens))
11891189
past_seen_token_num = begin_compute
11901190
num_cached_tokens_per_seq.append(past_seen_token_num)
1191+
request.py_batch_idx = py_batch_idx(request)
11911192

1192-
multimodal_embedding = request.multimodal_embedding
1193-
if multimodal_embedding is not None:
1194-
# TODO: Visit later once we have the SharedTensor.
1195-
multimodal_embedding = multimodal_embedding.pin_memory(
1196-
) if multimodal_embedding.device == 'cpu' else multimodal_embedding
1197-
multimodal_embedding = multimodal_embedding.to(
1198-
'cuda', non_blocking=True)
1199-
1200-
mrope_rotary_cos_sin = request.mrope_rotary_cos_sin
1201-
if mrope_rotary_cos_sin is not None:
1202-
# TODO: Visit later once we have the SharedTensor.
1203-
mrope_rotary_cos_sin = mrope_rotary_cos_sin.pin_memory(
1204-
) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin
1205-
mrope_rotary_cos_sin = mrope_rotary_cos_sin.to(
1206-
'cuda', non_blocking=True)
1207-
1208-
# Create MultimodalParams from request data
1193+
# Multimodal
12091194
multimodal_params = MultimodalParams(
1210-
multimodal_embedding=multimodal_embedding,
1211-
mrope_config={'mrope_rotary_cos_sin': mrope_rotary_cos_sin}
1212-
if mrope_rotary_cos_sin is not None else {},
1213-
multimodal_data=request.py_multimodal_data,
1214-
)
1195+
multimodal_data=request.py_multimodal_data, )
1196+
multimodal_params.strip_for_context()
1197+
multimodal_params.to_device("multimodal_data",
1198+
"cuda",
1199+
pin_memory=True)
12151200

12161201
if multimodal_params.has_content():
12171202
multimodal_params_list.append(multimodal_params)
@@ -1243,20 +1228,15 @@ def _prepare_tp_inputs(
12431228
extend_requests.append(request)
12441229
else:
12451230
generation_requests.append(request)
1246-
1247-
# Handle generation request multimodal params
1248-
mrope_position_deltas = request.mrope_position_deltas
1249-
if mrope_position_deltas is not None:
1250-
mrope_position_deltas_tensor = torch.tensor(
1251-
[mrope_position_deltas], dtype=torch.int32, pin_memory=True)
1252-
multimodal_params = MultimodalParams(
1253-
mrope_config={
1254-
'mrope_position_deltas':
1255-
mrope_position_deltas_tensor.to('cuda',
1256-
non_blocking=True)
1257-
})
1258-
if multimodal_params.has_content():
1259-
multimodal_params_list.append(multimodal_params)
1231+
# Multimodal
1232+
multimodal_params = MultimodalParams(
1233+
multimodal_data=request.py_multimodal_data, )
1234+
multimodal_params.strip_for_generation()
1235+
multimodal_params.to_device("multimodal_data",
1236+
"cuda",
1237+
pin_memory=True)
1238+
if multimodal_params.has_content():
1239+
multimodal_params_list.append(multimodal_params)
12601240
extend_requests += extend_dummy_requests
12611241

12621242
if not self._disable_overlap_scheduler and self.is_spec_decode:

0 commit comments

Comments
 (0)