@@ -41,6 +41,110 @@ def __call__(
4141 ...
4242
4343
class BaseMultimodalInputProcessor:
    """
    Base class for multimodal input processors with default implementations
    of get_num_tokens_per_image and get_num_tokens_per_video methods.

    This class provides default implementations that work with most
    AutoProcessor-based models. Specific processors can override these
    methods if they need custom logic.
    """

    def _resolve_token_count_processor(self):
        """Return the wrapped processor exposing ``_get_num_multimodal_tokens``.

        Checks the ``processor`` attribute first, then ``_processor``
        (e.g. Mistral3 stores it under the underscored name).
        TODO: unify the naming convention for the processor attribute.

        Returns:
            The processor object providing ``_get_num_multimodal_tokens``,
            or None when neither attribute supplies one.
        """
        for attr_name in ('processor', '_processor'):
            candidate = getattr(self, attr_name, None)
            if candidate is not None and hasattr(candidate,
                                                 '_get_num_multimodal_tokens'):
                return candidate
        return None

    def get_num_tokens_per_image(
        self,
        *,
        image_width: int,
        image_height: int,
        **kwargs,
    ) -> int:
        """
        Calculate the number of tokens generated for an image.

        Default implementation assumes the processor has either:
        1. A 'processor' attribute with _get_num_multimodal_tokens method
        2. A '_processor' attribute with _get_num_multimodal_tokens method

        Override this method for custom implementations.

        Args:
            image_width: Image width in pixels (keyword-only).
            image_height: Image height in pixels (keyword-only).
            **kwargs: Forwarded to ``_get_num_multimodal_tokens``.

        Returns:
            Number of tokens the model produces for this image.

        Raises:
            NotImplementedError: If no suitable processor attribute is found.
        """
        processor = self._resolve_token_count_processor()
        if processor is None:
            raise NotImplementedError(
                f"get_num_tokens_per_image not implemented for {self.__class__.__name__}. "
                "Please override this method or ensure the processor has _get_num_multimodal_tokens method."
            )
        # HF convention: sizes are (height, width) tuples.
        image_size = (image_height, image_width)
        return processor._get_num_multimodal_tokens(
            [image_size], **kwargs)["num_image_tokens"][0]

    def get_num_tokens_per_video(
        self,
        *,
        video_width: int,
        video_height: int,
        num_frames: int,
        **kwargs,
    ) -> int:
        """
        Calculate the number of tokens generated for a video.

        Default implementation assumes the processor has either:
        1. A 'processor' attribute with _get_num_multimodal_tokens method
        2. A '_processor' attribute with _get_num_multimodal_tokens method

        Override this method for custom implementations.

        Args:
            video_width: Frame width in pixels (keyword-only).
            video_height: Frame height in pixels (keyword-only).
            num_frames: Number of frames in the video (keyword-only).
            **kwargs: Forwarded to ``_get_num_multimodal_tokens``.

        Returns:
            Number of tokens the model produces for this video.

        Raises:
            NotImplementedError: If no suitable processor attribute is found.
        """
        processor = self._resolve_token_count_processor()
        if processor is None:
            raise NotImplementedError(
                f"get_num_tokens_per_video not implemented for {self.__class__.__name__}. "
                "Please override this method or ensure the processor has _get_num_multimodal_tokens method."
            )
        video_size = (num_frames, video_height, video_width)
        try:
            # Preferred path: the processor can count video tokens directly.
            return processor._get_num_multimodal_tokens(
                video_sizes=[video_size], **kwargs)["num_video_tokens"][0]
        except Exception:
            # Deliberate best-effort fallback: not every processor supports
            # video sizes, so treat the video as a sequence of frames and
            # divide by the temporal patch size when the model defines one.
            num_tokens_per_frame = self.get_num_tokens_per_image(
                image_width=video_width,
                image_height=video_height,
                **kwargs)
            temporal_patch_size = getattr(self, 'temporal_patch_size', 1)
            return num_tokens_per_frame * num_frames // temporal_patch_size
146+
147+
44148class DefaultInputProcessor (InputProcessor ):
45149 """Preprocess the inputs to the model."""
46150
@@ -327,6 +431,8 @@ def multimodal_hashing_process(
327431 assert 'multi_modal_data' in inputs , "multi_modal_data must be provided for hashing support."
328432 mm_data = inputs ['multi_modal_data' ]
329433 num_mm_tokens = find_mm_token_lengths (mm_data , input_processor )
434+ # TODO: here we assume there is only one modality for now
435+ num_mm_tokens = next (iter (num_mm_tokens .values ()))
330436 if len (num_mm_tokens ) > 0 :
331437 mm_hashes = apply_mm_hashes (mm_data , hash_lib )
332438 prompt_token_ids , extra_processed_inputs = input_processor (
@@ -358,8 +464,8 @@ def input_processor_wrapper(
358464 modalities = list (set (inputs ['multi_modal_data' ].keys ())
359465 ) if 'multi_modal_data' in inputs else []
360466 if len (modalities ) > 0 :
361- # NOTE: tensorrt_llm/inputs/multimodal.py:find_mm_token_lengths only supports image data for now
362- if len (modalities ) == 1 and modalities [ 0 ] == "image" :
467+ # TODO: support multiple modalities for multimodal hashing ( for kv cache reuse, chunked prefill, etc.)
468+ if len (modalities ) == 1 :
363469 # only try multimodal hashing if the inputs only contain image data
364470 if input_processor .multimodal_hashing_supported is not None :
365471 use_multimodal_hashing = input_processor .multimodal_hashing_supported
0 commit comments