Skip to content

Commit b859c4d

Browse files
committed
Address comments
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent cf445ee commit b859c4d

File tree

5 files changed

+58
-86
lines changed

5 files changed

+58
-86
lines changed

examples/llm-api/quickstart_multimodal.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def parse_arguments():
145145
parser = add_lora_args(parser)
146146
args = parser.parse_args()
147147

148-
args.disable_kv_cache_reuse = True # kv cache reuse does not work for multimodal, force overwrite
148+
args.disable_kv_cache_reuse = False # force overwrite: kv cache reuse is now enabled for multimodal (this line previously forced it off)
149149
if args.kv_cache_fraction is None:
150150
args.kv_cache_fraction = 0.6 # lower the default kv cache fraction for multimodal
151151

@@ -177,6 +177,19 @@ def main():
177177

178178
llm, sampling_params = setup_llm(args, lora_config=lora_config)
179179

180+
# from tensorrt_llm import MultimodalEncoder, SamplingParams
181+
# sampling_params = SamplingParams(max_tokens=args.max_tokens)
182+
# llm = MultimodalEncoder(
183+
# model=args.model_dir,
184+
# backend='pytorch',
185+
# disable_overlap_scheduler=args.disable_overlap_scheduler,
186+
# max_seq_len=args.max_seq_len,
187+
# max_batch_size=args.max_batch_size,
188+
# max_num_tokens=args.max_num_tokens,
189+
# trust_remote_code=args.trust_remote_code,
190+
# )
191+
192+
180193
image_format = args.image_format
181194
if args.model_type is not None:
182195
model_type = args.model_type
@@ -197,8 +210,8 @@ def main():
197210
model_dir=str(llm._hf_model_dir),
198211
model_type=model_type,
199212
modality=args.modality,
200-
prompts=args.prompt,
201-
media=args.media,
213+
prompts=[args.prompt[0], args.prompt[0]],
214+
media=[args.media[0], args.media[0]],
202215
image_data_format=image_format,
203216
num_frames=args.num_frames,
204217
device=args.device)
@@ -211,7 +224,6 @@ def main():
211224
outputs = llm.generate(
212225
inputs,
213226
sampling_params,
214-
lora_request=lora_request,
215227
)
216228

217229
for i, output in enumerate(outputs):

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(self,
4444
trust_remote_code=trust_remote_code)
4545

4646
self.tllm_multimodal_token_id = self.model_config.vocab_size + 1
47+
# temporal patch size for video frames
4748
self.temporal_patch_size = getattr(model_config.vision_config,
4849
'temporal_patch_size', 1)
4950

tensorrt_llm/inputs/multimodal.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,8 @@ def _hash_image(image):
445445
# Hash each frame with a separator to avoid collisions between [A,B] and [AB]
446446
for frame in image:
447447
hasher.update(b"<frame>")
448+
if isinstance(frame, torch.Tensor):
449+
frame = frame.detach().cpu().contiguous()
448450
hasher.update(serialize_item(frame))
449451
else:
450452
hasher.update(serialize_item(image))

tensorrt_llm/inputs/registry.py

Lines changed: 39 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,24 @@ class BaseMultimodalInputProcessor:
5050
models. Specific processors can override these methods if they need custom logic.
5151
"""
5252

53+
@property
54+
def get_num_multimodal_tokens(self):
55+
"""
56+
Get the Hugging Face processor's '_get_num_multimodal_tokens' method.
57+
58+
"""
59+
if hasattr(self, 'processor') and hasattr(self.processor,
60+
'_get_num_multimodal_tokens'):
61+
return self.processor._get_num_multimodal_tokens
62+
elif hasattr(self, '_processor') and hasattr(
63+
self._processor, '_get_num_multimodal_tokens'):
64+
return self._processor._get_num_multimodal_tokens
65+
else:
66+
raise NotImplementedError(
67+
f"get_num_multimodal_tokens not implemented for {self.__class__.__name__}. "
68+
"Please override this method or ensure the processor has _get_num_multimodal_tokens method."
69+
)
70+
5371
def get_num_tokens_per_image(
5472
self,
5573
*,
@@ -60,30 +78,14 @@ def get_num_tokens_per_image(
6078
"""
6179
Calculate the number of tokens generated for an image.
6280
63-
Default implementation assumes the processor has either:
64-
1. A 'processor' attribute with _get_num_multimodal_tokens method
65-
2. A '_processor' attribute with _get_num_multimodal_tokens method
81+
This (default) method delegates to the Hugging Face processor's '_get_num_multimodal_tokens' method.
82+
Returns the token count for the given image.
6683
67-
Override this method for custom implementations.
84+
Subclasses can override this method to provide custom logic to calculate the number of tokens.
6885
"""
69-
if hasattr(self, 'processor') and hasattr(self.processor,
70-
'_get_num_multimodal_tokens'):
71-
image_size = (image_height, image_width)
72-
num_image_tokens = self.processor._get_num_multimodal_tokens(
73-
[image_size], **kwargs)["num_image_tokens"][0]
74-
return num_image_tokens
75-
# Check for _processor attribute (e.g., Mistral3)
76-
elif hasattr(self, '_processor') and hasattr(
77-
self._processor, '_get_num_multimodal_tokens'):
78-
image_size = (image_height, image_width)
79-
num_image_tokens = self._processor._get_num_multimodal_tokens(
80-
[image_size], **kwargs)["num_image_tokens"][0]
81-
return num_image_tokens
82-
else:
83-
raise NotImplementedError(
84-
f"get_num_tokens_per_image not implemented for {self.__class__.__name__}. "
85-
"Please override this method or ensure the processor has _get_num_multimodal_tokens method."
86-
)
86+
image_size = (image_height, image_width)
87+
return self.get_num_multimodal_tokens([image_size],
88+
**kwargs)["num_image_tokens"][0]
8789

8890
def get_num_tokens_per_video(
8991
self,
@@ -96,53 +98,23 @@ def get_num_tokens_per_video(
9698
"""
9799
Calculate the number of tokens generated for a video.
98100
99-
Default implementation assumes the processor has either:
100-
1. A 'processor' attribute with _get_num_multimodal_tokens method
101-
2. A '_processor' attribute with _get_num_multimodal_tokens method
101+
This (default) method delegates to the Hugging Face processor's '_get_num_multimodal_tokens' method.
102+
Returns the token count for the given video.
102103
103-
Override this method for custom implementations.
104+
Subclasses can override this method to provide custom logic to calculate the number of tokens.
104105
"""
105-
if hasattr(self, 'processor') and hasattr(self.processor,
106-
'_get_num_multimodal_tokens'):
107-
video_size = (num_frames, video_height, video_width)
108-
# Try to get video tokens directly
109-
try:
110-
num_video_tokens = self.processor._get_num_multimodal_tokens(
111-
video_sizes=[video_size], **kwargs)["num_video_tokens"][0]
112-
return num_video_tokens
113-
except Exception:
114-
# Fallback: treat video as sequence of frames
115-
num_tokens_per_frame = self.get_num_tokens_per_image(
116-
image_width=video_width,
117-
image_height=video_height,
118-
**kwargs)
119-
temporal_patch_size = self.temporal_patch_size if hasattr(
120-
self, 'temporal_patch_size') else 1
121-
return num_tokens_per_frame * num_frames // temporal_patch_size
122-
# Check for _processor attribute (e.g., Mistral3)
123-
# TODO: unify the naming convention for the processor attribute
124-
elif hasattr(self, '_processor') and hasattr(
125-
self._processor, '_get_num_multimodal_tokens'):
126-
video_size = (num_frames, video_height, video_width)
127-
# Try to get video tokens directly
128-
try:
129-
num_video_tokens = self._processor._get_num_multimodal_tokens(
130-
video_sizes=[video_size], **kwargs)["num_video_tokens"][0]
131-
return num_video_tokens
132-
except Exception:
133-
# Fallback: treat video as sequence of frames
134-
num_tokens_per_frame = self.get_num_tokens_per_image(
135-
image_width=video_width,
136-
image_height=video_height,
137-
**kwargs)
138-
temporal_patch_size = self.temporal_patch_size if hasattr(
139-
self, 'temporal_patch_size') else 1
140-
return num_tokens_per_frame * num_frames // temporal_patch_size
141-
else:
142-
raise NotImplementedError(
143-
f"get_num_tokens_per_video not implemented for {self.__class__.__name__}. "
144-
"Please override this method or ensure the processor has _get_num_multimodal_tokens method."
145-
)
106+
video_size = (num_frames, video_height, video_width)
107+
try:
108+
num_video_tokens = self.get_num_multimodal_tokens(
109+
video_sizes=[video_size], **kwargs)["num_video_tokens"][0]
110+
return num_video_tokens
111+
except Exception:
112+
# Fallback: treat video as sequence of frames
113+
num_tokens_per_frame = self.get_num_tokens_per_image(
114+
image_width=video_width, image_height=video_height, **kwargs)
115+
temporal_patch_size = self.temporal_patch_size if hasattr(
116+
self, 'temporal_patch_size') else 1
117+
return num_tokens_per_frame * num_frames // temporal_patch_size
146118

147119

148120
class DefaultInputProcessor(InputProcessor):

tests/unittest/_torch/multimodal/test_find_num_image_tokens.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from tensorrt_llm import MultimodalEncoder
99
from tensorrt_llm._torch.models.modeling_llava_next import \
1010
LlavaNextInputProcessor
11-
from tensorrt_llm._torch.models.modeling_mistral import Mistral3InputProcessor
1211
from tensorrt_llm._torch.models.modeling_qwen2vl import \
1312
Qwen2VLInputProcessorBase
1413
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
@@ -47,11 +46,6 @@ def multimodal_model_configs():
4746
'hf_model_dir': 'Qwen/Qwen2.5-VL-3B-Instruct',
4847
'model_type': 'qwen2_5_vl',
4948
},
50-
'mistral-small-3.1': {
51-
'hf_model_dir':
52-
'/home/scratch.trt_llm_data/llm-models/Mistral-Small-3.1-24B-Instruct-2503',
53-
'model_type': 'mistral3',
54-
},
5549
}
5650
return model_configs
5751

@@ -149,9 +143,6 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs):
149143
image_height=image_height,
150144
num_frames=1,
151145
do_resize=True)
152-
elif model_type == 'mistral':
153-
predicted_num_tokens = input_processor.get_num_tokens_per_image(
154-
image_width=image_width, image_height=image_height)
155146
else:
156147
raise ValueError(f"Unsupported model type: {model_type}")
157148

@@ -216,12 +207,6 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
216207
model_config=model_config_dict,
217208
tokenizer=tokenizer,
218209
trust_remote_code=True)
219-
elif model_type == 'mistral':
220-
input_processor = Mistral3InputProcessor(
221-
model_path=encoder_model_dir,
222-
model_config=model_config_dict,
223-
tokenizer=tokenizer,
224-
trust_remote_code=True)
225210
else:
226211
pytest.fail(f"Unsupported model type: {model_type}")
227212

0 commit comments

Comments
 (0)