
Commit cf445ee (1 parent: a419b77)

Enable video encoder and generalize finding mm_token_length

Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>

File tree: 6 files changed, +286 -76 lines

tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 3 additions & 14 deletions
@@ -14,8 +14,8 @@
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...llmapi.utils import download_hf_model
@@ -32,7 +32,7 @@
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
 
 
-class LlavaNextInputProcessor(InputProcessor):
+class LlavaNextInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
 
     def __init__(self,
                  model_path: str,
@@ -56,17 +56,6 @@ def __init__(self,
         self.vocab_size = model_config.vocab_size
         self.config = model_config.vision_config
 
-    def get_num_tokens_per_image(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-    ) -> int:
-        image_size = (image_height, image_width)
-        num_image_tokens = self.processor._get_num_multimodal_tokens(
-            [image_size])["num_image_tokens"][0]
-        return num_image_tokens
-
     def _postprocess(
         self, input_ids: torch.Tensor, mm_features: Union[torch.Tensor,
                                                           List[torch.Tensor]]
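
The deleted method is exactly the logic that now lives in BaseMultimodalInputProcessor, so the subclass simply inherits it. A minimal sketch of the inherited path, where FakeHFProcessor is a hypothetical stand-in for the HuggingFace AutoProcessor (the private _get_num_multimodal_tokens hook is the real method the base class probes for):

```python
from tensorrt_llm.inputs import BaseMultimodalInputProcessor


class FakeHFProcessor:
    """Hypothetical processor: charges a flat 576 tokens per image."""

    def _get_num_multimodal_tokens(self, image_sizes, **kwargs):
        return {"num_image_tokens": [576 for _ in image_sizes]}


class ToyLlavaNextProcessor(BaseMultimodalInputProcessor):

    def __init__(self):
        # The base class probes for a 'processor' attribute first.
        self.processor = FakeHFProcessor()


toy = ToyLlavaNextProcessor()
print(toy.get_num_tokens_per_image(image_width=336, image_height=336))  # 576
```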

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 5 additions & 36 deletions
@@ -7,14 +7,13 @@
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
 from ..._utils import nvtx_range_debug
 from ...functional import RopeEmbeddingUtils, RotaryScalingType
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
@@ -29,7 +28,7 @@
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
 
 
-class Qwen2VLInputProcessorBase(InputProcessor):
+class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor, InputProcessor):
 
     def __init__(self,
                  model_path: str,
@@ -45,6 +44,8 @@ def __init__(self,
             trust_remote_code=trust_remote_code)
 
         self.tllm_multimodal_token_id = self.model_config.vocab_size + 1
+        self.temporal_patch_size = getattr(model_config.vision_config,
+                                           'temporal_patch_size', 1)
 
     @classmethod
     def get_rope_index(
@@ -220,38 +221,6 @@ def get_rope_index(
             mrope_position_deltas, device=input_ids.device).unsqueeze(1)
         return position_ids, mrope_position_deltas
 
-    def get_num_tokens_per_image(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        num_frames: int = 1,
-        do_resize: bool = True,
-    ):
-        patch_size = self.model_config.vision_config.patch_size
-        merge_size = self.model_config.vision_config.spatial_merge_size
-        temporal_patch_size = self.model_config.vision_config.temporal_patch_size
-        if do_resize:
-            resized_height, resized_width = smart_resize(
-                height=image_height,
-                width=image_width,
-                factor=patch_size * merge_size,
-                min_pixels=self.processor.image_processor.min_pixels,
-                max_pixels=self.processor.image_processor.max_pixels,
-            )
-            image_width, image_height = resized_width, resized_height
-
-        padded_num_frames = num_frames + num_frames % temporal_patch_size
-
-        grid_t = max(padded_num_frames // temporal_patch_size, 1)
-        grid_h = image_height // patch_size
-        grid_w = image_width // patch_size
-
-        num_patches = grid_t * grid_h * grid_w
-        num_vision_tokens = num_patches // (merge_size**2)
-
-        return num_vision_tokens
-
     def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
                     mm_processor_kwargs: Dict[str, Any]):
         images = mm_data.get("image")
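
The removed Qwen-specific method encoded the patch-grid arithmetic (spatial patching, spatial merging, temporal merging). A standalone sketch of that arithmetic, without the smart_resize step; the default parameter values below are the usual Qwen2-VL vision-config settings and are assumptions here, not values read from a checkpoint:

```python
def qwen2vl_num_vision_tokens(width: int, height: int, num_frames: int = 1,
                              patch_size: int = 14, merge_size: int = 2,
                              temporal_patch_size: int = 2) -> int:
    # Pad the frame count so it divides evenly into temporal patches.
    padded_num_frames = num_frames + num_frames % temporal_patch_size
    grid_t = max(padded_num_frames // temporal_patch_size, 1)
    grid_h = height // patch_size
    grid_w = width // patch_size
    # merge_size**2 spatial patches collapse into a single vision token.
    return grid_t * grid_h * grid_w // (merge_size ** 2)


# A 448x448 single image: (448 // 14)**2 // 4 = 256 tokens.
assert qwen2vl_num_vision_tokens(448, 448) == 256
```

With the base class in place, this model-specific computation is replaced by the generic AutoProcessor path, and temporal_patch_size is cached on the instance so the base class's video fallback can use it.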

tensorrt_llm/inputs/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
 from .multimodal import MultimodalInput
-from .registry import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from .registry import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, create_input_processor,
                        create_input_processor_with_hash,
                        register_input_processor)
@@ -27,6 +27,7 @@
     "create_input_processor_with_hash",
     "register_input_processor",
     "ExtraProcessedInputs",
+    "BaseMultimodalInputProcessor",
     "MultimodalPlaceholderMetadata",
     "MultimodalPlaceholderPlacement",
     "ConversationMessage",

tensorrt_llm/inputs/multimodal.py

Lines changed: 32 additions & 21 deletions
@@ -435,13 +435,20 @@ def apply_mm_hashes(mm_data: Dict[str, Any],
     """Apply hashing to multimodal data items."""
 
     def _hash_image(image):
-        # only support single modality w/ PIL.Image.Image for now
         # TODO: possible hash collision w/ this simplified version (vllm/PR/17378)
         hasher = hash_lib()
         if isinstance(image, torch.Tensor):
-            # TODO: Device tensor hashing is an open issue. Limited hashing to CPU for now.
-            image = image.cpu()
-        hasher.update(serialize_item(image))
+            # Ensure tensor is on CPU and contiguous for consistent hashing
+            image = image.detach().cpu().contiguous()
+            hasher.update(serialize_item(image))
+        elif isinstance(image, list):
+            # Hash each frame with a separator to avoid collisions between [A,B] and [AB]
+            for frame in image:
+                hasher.update(b"<frame>")
+                hasher.update(serialize_item(frame))
+        else:
+            hasher.update(serialize_item(image))
+
         return hasher.hexdigest()
 
     mm_items = {
@@ -483,31 +490,35 @@ def find_mm_token_lengths(mm_data: Dict[str, Any],
     num_mm_tokens = {}
 
     for modality, items in mm_items.items():
-        if modality != "image":
-            #TODO: support other modalities
-            raise ValueError(
-                f"Unsupported modality: {modality}. Only 'image' modality is currently supported for hashing."
-            )
-        if not hasattr(input_processor, "get_num_tokens_per_image"):
-            #TODO: backward compatibility for models that don't yet have get_num_tokens_per_image implemented
-            #TODO: only support qwen2_vl for now
+        if not hasattr(input_processor, f"get_num_tokens_per_{modality}"):
             raise AttributeError(
-                f"Input processor {type(input_processor).__name__} does not have 'get_num_tokens_per_image' method required for multimodal hashing."
+                f"Input processor {type(input_processor).__name__} does not have 'get_num_tokens_per_{modality}' method required for multimodal hashing."
             )
 
         modality_token_lengths = []
         for item in items:
-            if isinstance(item, torch.Tensor):
-                item = ToPILImage()(item)
-            num_tokens = input_processor.get_num_tokens_per_image(
-                image_width=item.width,
-                image_height=item.height,
-            )
-            modality_token_lengths.append(num_tokens)
+            if modality == "image":
+                if isinstance(item, torch.Tensor):
+                    item = ToPILImage()(item)
+                num_tokens = input_processor.get_num_tokens_per_image(
+                    image_width=item.width,
+                    image_height=item.height,
+                )
+                modality_token_lengths.append(num_tokens)
+            elif modality == "video":
+                assert isinstance(item, list), "Video must be a list of frames"
+                if isinstance(item[0], torch.Tensor):
+                    item = [ToPILImage()(frame) for frame in item]
+                num_tokens = input_processor.get_num_tokens_per_video(
+                    video_width=item[0].width,
+                    video_height=item[0].height,
+                    num_frames=len(item),
+                )
+                modality_token_lengths.append(num_tokens)
 
         num_mm_tokens[modality] = modality_token_lengths
 
-    return num_mm_tokens['image']  # flatten all mm instances to a single list
+    return num_mm_tokens  # per-modality lists of token lengths
 
 
 def find_mm_token_positions(input_ids: Union[torch.Tensor, List[int],
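
The b"<frame>" separator is what keeps a list of frames from colliding with its concatenation. A self-contained sketch of the idea, with raw bytes standing in for serialize_item output and sha256 standing in for hash_lib:

```python
import hashlib


def hash_frames(frames, separator=b"<frame>"):
    """Hash a list of frame byte-strings with a per-frame separator."""
    hasher = hashlib.sha256()
    for frame in frames:
        hasher.update(separator)
        hasher.update(frame)
    return hasher.hexdigest()


def hash_frames_unseparated(frames):
    """Naive variant: concatenation erases the frame boundaries."""
    hasher = hashlib.sha256()
    for frame in frames:
        hasher.update(frame)
    return hasher.hexdigest()


# Without a separator, [b"AB"] and [b"A", b"B"] hash identically...
assert hash_frames_unseparated([b"AB"]) == hash_frames_unseparated([b"A", b"B"])
# ...with the separator, the boundary becomes part of the digest.
assert hash_frames([b"AB"]) != hash_frames([b"A", b"B"])
```

Note also the return-type change: find_mm_token_lengths now returns a dict keyed by modality (for example {"image": [...]} or {"video": [...]}) instead of a flat image list; the caller in registry.py unpacks it, as shown in the next file.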

tensorrt_llm/inputs/registry.py

Lines changed: 108 additions & 2 deletions
@@ -41,6 +41,110 @@ def __call__(
         ...
 
 
+class BaseMultimodalInputProcessor:
+    """
+    Base class for multimodal input processors with default implementations
+    of get_num_tokens_per_image and get_num_tokens_per_video methods.
+
+    This class provides default implementations that work with most AutoProcessor-based
+    models. Specific processors can override these methods if they need custom logic.
+    """
+
+    def get_num_tokens_per_image(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        **kwargs,
+    ):
+        """
+        Calculate the number of tokens generated for an image.
+
+        Default implementation assumes the processor has either:
+        1. A 'processor' attribute with _get_num_multimodal_tokens method
+        2. A '_processor' attribute with _get_num_multimodal_tokens method
+
+        Override this method for custom implementations.
+        """
+        if hasattr(self, 'processor') and hasattr(self.processor,
+                                                  '_get_num_multimodal_tokens'):
+            image_size = (image_height, image_width)
+            num_image_tokens = self.processor._get_num_multimodal_tokens(
+                [image_size], **kwargs)["num_image_tokens"][0]
+            return num_image_tokens
+        # Check for _processor attribute (e.g., Mistral3)
+        elif hasattr(self, '_processor') and hasattr(
+                self._processor, '_get_num_multimodal_tokens'):
+            image_size = (image_height, image_width)
+            num_image_tokens = self._processor._get_num_multimodal_tokens(
+                [image_size], **kwargs)["num_image_tokens"][0]
+            return num_image_tokens
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_per_image not implemented for {self.__class__.__name__}. "
+                "Please override this method or ensure the processor has _get_num_multimodal_tokens method."
+            )
+
+    def get_num_tokens_per_video(
+        self,
+        *,
+        video_width: int,
+        video_height: int,
+        num_frames: int,
+        **kwargs,
+    ):
+        """
+        Calculate the number of tokens generated for a video.
+
+        Default implementation assumes the processor has either:
+        1. A 'processor' attribute with _get_num_multimodal_tokens method
+        2. A '_processor' attribute with _get_num_multimodal_tokens method
+
+        Override this method for custom implementations.
+        """
+        if hasattr(self, 'processor') and hasattr(self.processor,
+                                                  '_get_num_multimodal_tokens'):
+            video_size = (num_frames, video_height, video_width)
+            # Try to get video tokens directly
+            try:
+                num_video_tokens = self.processor._get_num_multimodal_tokens(
+                    video_sizes=[video_size], **kwargs)["num_video_tokens"][0]
+                return num_video_tokens
+            except Exception:
+                # Fallback: treat video as sequence of frames
+                num_tokens_per_frame = self.get_num_tokens_per_image(
+                    image_width=video_width,
+                    image_height=video_height,
+                    **kwargs)
+                temporal_patch_size = self.temporal_patch_size if hasattr(
+                    self, 'temporal_patch_size') else 1
+                return num_tokens_per_frame * num_frames // temporal_patch_size
+        # Check for _processor attribute (e.g., Mistral3)
+        # TODO: unify the naming convention for the processor attribute
+        elif hasattr(self, '_processor') and hasattr(
+                self._processor, '_get_num_multimodal_tokens'):
+            video_size = (num_frames, video_height, video_width)
+            # Try to get video tokens directly
+            try:
+                num_video_tokens = self._processor._get_num_multimodal_tokens(
+                    video_sizes=[video_size], **kwargs)["num_video_tokens"][0]
+                return num_video_tokens
+            except Exception:
+                # Fallback: treat video as sequence of frames
+                num_tokens_per_frame = self.get_num_tokens_per_image(
+                    image_width=video_width,
+                    image_height=video_height,
+                    **kwargs)
+                temporal_patch_size = self.temporal_patch_size if hasattr(
+                    self, 'temporal_patch_size') else 1
+                return num_tokens_per_frame * num_frames // temporal_patch_size
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_per_video not implemented for {self.__class__.__name__}. "
+                "Please override this method or ensure the processor has _get_num_multimodal_tokens method."
+            )
+
+
 class DefaultInputProcessor(InputProcessor):
     """Preprocess the inputs to the model."""
 
@@ -327,6 +431,8 @@ def multimodal_hashing_process(
         assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
         mm_data = inputs['multi_modal_data']
         num_mm_tokens = find_mm_token_lengths(mm_data, input_processor)
+        # TODO: here we assume there is only one modality for now
+        num_mm_tokens = next(iter(num_mm_tokens.values()))
        if len(num_mm_tokens) > 0:
             mm_hashes = apply_mm_hashes(mm_data, hash_lib)
             prompt_token_ids, extra_processed_inputs = input_processor(
@@ -358,8 +464,8 @@ def input_processor_wrapper(
         modalities = list(set(inputs['multi_modal_data'].keys())
                           ) if 'multi_modal_data' in inputs else []
         if len(modalities) > 0:
-            # NOTE: tensorrt_llm/inputs/multimodal.py:find_mm_token_lengths only supports image data for now
-            if len(modalities) == 1 and modalities[0] == "image":
+            # TODO: support multiple modalities for multimodal hashing (for kv cache reuse, chunked prefill, etc.)
+            if len(modalities) == 1:
                 # only try multimodal hashing if the inputs only contain image data
                 if input_processor.multimodal_hashing_supported is not None:
                     use_multimodal_hashing = input_processor.multimodal_hashing_supported
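
The video path first asks the underlying HF processor for num_video_tokens and only falls back to per-frame image pricing, divided by temporal_patch_size, when that call fails. A hedged sketch of the fallback, with ImageOnlyProcessor as a hypothetical processor that rejects video sizes:

```python
from tensorrt_llm.inputs import BaseMultimodalInputProcessor


class ImageOnlyProcessor:
    """Hypothetical processor that only understands images (512 tokens each)."""

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        if image_sizes is None:
            raise TypeError("videos not supported")  # triggers the fallback
        return {"num_image_tokens": [512 for _ in image_sizes]}


class ToyVideoProcessor(BaseMultimodalInputProcessor):

    def __init__(self):
        self.processor = ImageOnlyProcessor()
        self.temporal_patch_size = 2  # as cached by Qwen2VLInputProcessorBase


toy = ToyVideoProcessor()
# Direct video query raises, so the base class prices the video as
# 8 frames * 512 tokens per frame // temporal_patch_size 2 = 2048.
print(toy.get_num_tokens_per_video(video_width=448, video_height=448,
                                   num_frames=8))
```

The broad except Exception keeps the fallback working across transformers versions whose _get_num_multimodal_tokens does not accept video_sizes, at the cost of also masking genuine errors in that call.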
