
Commit 2d15a9b

fix: Mistral Small vision encoder with BS>1
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 06eba1e commit 2d15a9b

File tree: 4 files changed, +89 −34 lines

tensorrt_llm/runtime/multimodal_model_runner.py

Lines changed: 82 additions & 31 deletions
@@ -909,29 +909,37 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs,
         elif self.model_type == 'pixtral':
             # Hold on to pixel_values and input_ids.
             dtype = str_dtype_to_torch(self.vision_precision)
-            pixel_values = image["pixel_values"].to(device="cuda", dtype=dtype)
-            input_ids = image["input_ids"].to(device="cuda")
-
             # Shape of pixel values from the processor varies with the raw image.
             # So we create a new tensor with a fixed shape as expected by the vision
             # encoder and create a corresponding attention mask.
             image_size = self.image_size
             patch_size = self.patch_size
             d_min = torch.finfo(dtype).min
             num_patches = (image_size // patch_size)
-            image = torch.full((1, 3, image_size, image_size),
-                               fill_value=0,
-                               dtype=dtype,
-                               device="cuda")
-            attention_mask = torch.full((1, num_patches, num_patches),
-                                        fill_value=d_min,
-                                        dtype=dtype,
-                                        device="cuda")
-            h, w = pixel_values.shape[-2:]
-            image[..., :h, :w] = pixel_values
-            attention_mask[..., :h // patch_size, :w // patch_size] = 0
+            padded_image = torch.full(
+                (self.args.batch_size, 3, image_size, image_size),
+                fill_value=0,
+                dtype=dtype,
+                device="cuda")
+            padded_attention_mask = torch.full(
+                (self.args.batch_size, num_patches, num_patches),
+                fill_value=d_min,
+                dtype=dtype,
+                device="cuda")
+            h, w, input_ids = [], [], []
+            for img_idx in range(self.args.batch_size):
+                pixel_values = image["pixel_values"][img_idx]
+                img_h, img_w = pixel_values.shape[-2:]
+                padded_image[img_idx, :, :img_h, :img_w] = pixel_values
+                padded_attention_mask[img_idx, :img_h // patch_size, :img_w //
+                                      patch_size] = 0
+                input_ids.append(image["input_ids"][img_idx])
+                h.append(img_h)
+                w.append(img_w)
+
+            image = padded_image
             other_vision_inputs = {
-                "attention_mask": attention_mask,
+                "attention_mask": padded_attention_mask,
             }
         elif self.model_type == 'llava_next':
             input = image
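For reference, a minimal standalone sketch of the batching scheme this hunk introduces: each sample's variable-size pixel tensor is copied into a fixed [batch, 3, image_size, image_size] buffer, and an additive mask keeps only the real patches. The helper name, sizes, and dtype below are illustrative, not part of the runner.

```python
import torch

def pad_batch(images, image_size, patch_size, dtype=torch.float16):
    # `images` stands in for the per-sample pixel_values returned by the processor.
    batch_size = len(images)
    num_patches = image_size // patch_size
    d_min = torch.finfo(dtype).min
    padded = torch.zeros(batch_size, 3, image_size, image_size, dtype=dtype)
    # Additive mask: 0 keeps a patch, d_min (effectively -inf) masks padding out.
    mask = torch.full((batch_size, num_patches, num_patches), d_min, dtype=dtype)
    heights, widths = [], []
    for i, img in enumerate(images):
        h, w = img.shape[-2:]
        padded[i, :, :h, :w] = img.to(dtype)
        mask[i, :h // patch_size, :w // patch_size] = 0
        heights.append(h)
        widths.append(w)
    return padded, mask, heights, widths

# Example: two images of different sizes end up in one fixed-shape batch.
imgs = [torch.rand(3, 448, 672), torch.rand(3, 224, 224)]
batch, mask, hs, ws = pad_batch(imgs, image_size=672, patch_size=14)
print(batch.shape, mask.shape, hs, ws)  # (2, 3, 672, 672), (2, 48, 48), [448, 224], [672, 224]
```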
@@ -1150,12 +1158,29 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs,
         elif self.model_type == 'pixtral':
             relevant_patch_size = self.patch_size * self.spatial_merge_size
             output_img_size = self.image_size // relevant_patch_size
-            visual_features = visual_features.reshape(
-                output_img_size, output_img_size,
-                -1)[:h // relevant_patch_size, :w //
-                    relevant_patch_size].flatten(0, 1)
+            # Note: max_h * max_w shall serve as the `tokens_per_task` in ptuning prompt table.
+            max_h = max(h) // relevant_patch_size
+            max_w = max(w) // relevant_patch_size
+            visual_embed_dim = visual_features.shape[-1]
+            relevant_visual_features = torch.zeros(self.args.batch_size,
+                                                   max_h * max_w,
+                                                   visual_embed_dim)
+            for img_idx in range(self.args.batch_size):
+                complete_features = visual_features[img_idx]
+                complete_features = complete_features.reshape(
+                    output_img_size, output_img_size, visual_embed_dim)
+                relevant_h = h[img_idx] // relevant_patch_size
+                relevant_w = w[img_idx] // relevant_patch_size
+                flattened_features = complete_features[:relevant_h, :
+                                                       relevant_w, :].flatten(
+                                                           0, 1)
+                relevant_visual_features[img_idx, :relevant_h *
+                                         relevant_w, :] = flattened_features
+            visual_features = relevant_visual_features
             input_ids = self.ptuning_setup_pixtral(input_ids=input_ids)
-            length = input_ids.shape[1]
+            # Note: length is not used for pixtral model downstream. Setting it to a list
+            # of length of input_ids causes errors downstream. So, supplying a placeholder.
+            length = input_ids[0].shape[0]

         elif self.model_type == 'llava_next':
             visual_features = LlavaNextUtils.rearrange_image_features(
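A hedged sketch of the feature-packing step above: each image's encoder output is reshaped back to its patch grid, cropped to the valid region, and left-aligned into one zero-padded [batch, max_h * max_w, dim] prompt table. The function name and shape values below are invented for illustration.

```python
import torch

def pack_features(features, h, w, grid, merged_patch):
    # features: encoder output of shape [batch, grid * grid, dim]
    # h, w: original pixel heights/widths; merged_patch: patch_size * spatial_merge_size
    batch, _, dim = features.shape
    max_h = max(h) // merged_patch
    max_w = max(w) // merged_patch
    packed = torch.zeros(batch, max_h * max_w, dim)
    for i in range(batch):
        per_img = features[i].reshape(grid, grid, dim)
        rh, rw = h[i] // merged_patch, w[i] // merged_patch
        valid = per_img[:rh, :rw, :].flatten(0, 1)  # keep only real patches
        packed[i, :rh * rw, :] = valid              # left-align, zero-pad the rest
    return packed

feats = torch.rand(2, 48 * 48, 64)  # grid=48, dim=64 (illustrative)
packed = pack_features(feats, h=[448, 224], w=[672, 224], grid=48, merged_patch=28)
print(packed.shape)  # (2, 384, 64): max_h=16, max_w=24
```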
@@ -2027,16 +2052,19 @@ def ptuning_setup_fuyu(self, input_ids, image_patches_indices):

     def ptuning_setup_pixtral(self, input_ids):
         # input_ids obtained from processor has token_ids for text as well as image tokens
-        # where each image token is represented the same image_token_index (10 for this model).
+        # where each image token is represented by the same image_token_index.
         image_token_index = self.image_token_index
         vocab_size = self.vocab_size
         # Replace all image tokens with a unique token_id > text_vacab_size.
         # This shall be used to lookup the prompt table.
-        replacer = vocab_size
-        for i in range(len(input_ids[0])):
-            if input_ids[0][i] == image_token_index:
-                input_ids[0][i] = replacer
-                replacer += 1
+        for img_idx in range(self.args.batch_size):
+            # Note: We reset replacer to text_vocab_size for each sample. This is as opposed to doing `replacer = vocab_size + img_idx * tokens_per_task`.
+            # That part of the look-up manipulation is done by the `task_ids` input to PromptEmbedding forward.
+            replacer = vocab_size
+            for token_idx in range(len(input_ids[img_idx])):
+                if input_ids[img_idx][token_idx] == image_token_index:
+                    input_ids[img_idx][token_idx] = replacer
+                    replacer += 1
         return input_ids

     def ptuning_setup_llava_next(self, visual_features, pre_prompt,
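A toy, self-contained illustration of the replacement loop: image placeholder tokens become consecutive ids starting at vocab_size, reset per sample (the per-sample offset is handled by task_ids, per the comment above). The token ids and vocab size below are made up for the example.

```python
# Every image placeholder token gets a unique id >= vocab_size so it indexes
# a row of that sample's prompt-table slice.
image_token_index = 10
vocab_size = 100

input_ids = [
    [1, 10, 10, 10, 2],  # sample 0: 3 image tokens
    [1, 10, 10, 3, 4],   # sample 1: 2 image tokens
]
for sample in input_ids:
    replacer = vocab_size  # reset per sample; task_ids pick the right table slice
    for t, tok in enumerate(sample):
        if tok == image_token_index:
            sample[t] = replacer
            replacer += 1

print(input_ids)  # [[1, 100, 101, 102, 2], [1, 100, 101, 3, 4]]
```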
@@ -2166,7 +2194,24 @@ def load_images(image_paths):
                 if isinstance(image_path, str):
                     image_path = image_path.split(self.args.path_sep)
                 images = load_images(image_path)
-
+            elif "pixtral" in self.model_type:
+                if image_path is None:
+                    image_urls = [
+                        "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png",
+                        "https://www.ilankelman.org/stopsigns/australia.jpg",
+                        "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png",
+                        "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                    ]
+                    while len(image_urls) < self.args.batch_size:
+                        image_urls *= 2
+                    image_urls = image_urls[:self.args.batch_size]
+                    self.args.image_path = ",".join(image_urls)
+                    images = load_images(image_urls)
+                else:
+                    if isinstance(image_path, str):
+                        image_path = image_path.split(self.args.path_sep)
+                    images = load_images(image_path)
+                images = [images] if not isinstance(images, list) else images
             elif "nougat" in self.model_type:
                 filepath = hf_hub_download(
                     repo_id="hf-internal-testing/fixtures_docvqa",
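The default-image branch above simply tiles a short URL list until it covers batch_size, then truncates. A roughly equivalent sketch using itertools (placeholder file names, not the real default URLs):

```python
from itertools import cycle, islice

urls = ["a.png", "b.jpg", "c.png", "d.jpeg"]  # placeholders for the default URLs
batch_size = 6
tiled = list(islice(cycle(urls), batch_size))
print(tiled)  # ['a.png', 'b.jpg', 'c.png', 'd.jpeg', 'a.png', 'b.jpg']
```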
@@ -2413,9 +2458,15 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None):
             post_prompt = "[/INST]"
             prompt = pre_prompt + input_text + post_prompt
             dtype = str_dtype_to_torch(self.vision_precision)
-            image = self.processor(text=prompt,
-                                   images=[raw_image],
-                                   return_tensors="pt").to(dtype)
+            image = {'pixel_values': [], 'input_ids': []}
+            for img_idx in range(self.args.batch_size):
+                image_info = self.processor(text=prompt,
+                                            images=[raw_image[img_idx]],
+                                            return_tensors="pt").to(dtype)
+                image['pixel_values'].append(image_info['pixel_values'].to(
+                    self.device))
+                image['input_ids'].append(image_info['input_ids'][0].to(
+                    self.device))

         elif 'internvl' in self.model_type:
             pre_prompt = "<|system|>\n你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|end|><|user|>\n<image>\n"
@@ -2619,7 +2670,7 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None):
             image = image.expand(
                 min(self.args.batch_size, len(input_text)), -1, -1,
                 -1).contiguous()
-        if image is not None:
+        if image is not None and isinstance(image, torch.Tensor):
             image = image.to(self.device)
         # Generate decoder_input_ids for enc-dec models
         # Custom prompts can be added as:

tensorrt_llm/tools/multimodal_builder.py

Lines changed: 5 additions & 1 deletion
@@ -1627,8 +1627,12 @@ def attn_forward(self,
     cos, sin = position_embeddings
     q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

+    # attention_mask is of shape [batch, patches].
+    mask = attention_mask[:, None, None, :]
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
-        q, k, v, attn_mask=attention_mask).transpose(1, 2).contiguous()
+        q, k, v, attn_mask=mask).transpose(1, 2).contiguous()
+
     attn_output = attn_output.reshape(batch, patches, -1)
     attn_output = self.o_proj(attn_output)

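Why the builder change reshapes the mask: torch.nn.functional.scaled_dot_product_attention takes an additive float mask broadcastable to [batch, heads, q_len, kv_len], so a per-patch [batch, patches] mask needs the two singleton dims inserted. A small self-contained sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

batch, heads, patches, head_dim = 2, 4, 9, 8
q = torch.rand(batch, heads, patches, head_dim)
k = torch.rand(batch, heads, patches, head_dim)
v = torch.rand(batch, heads, patches, head_dim)

# 0 for real patches, a large negative value for padded ones.
attention_mask = torch.zeros(batch, patches)
attention_mask[1, 5:] = torch.finfo(q.dtype).min  # sample 1 has 4 padded patches

mask = attention_mask[:, None, None, :]  # -> [batch, 1, 1, patches], broadcasts over heads and queries
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([2, 4, 9, 8])
```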
tests/integration/test_lists/qa/examples_test_list.txt

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-h
 examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-hf-vision-trtllm-pp:1-tp:2-float16-bs:1-cpp_e2e:False-nb:1]
 examples/test_multimodal.py::test_llm_multimodal_general[llava-onevision-qwen2-7b-ov-hf-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
 examples/test_multimodal.py::test_llm_multimodal_general[llava-onevision-qwen2-7b-ov-hf-video-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
-examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]
+examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1]
 examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]
 examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1]
 examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ l0_h100:
   - examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
   - examples/test_multimodal.py::test_llm_multimodal_general[Phi-3.5-vision-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
   - examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
-  - examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]
+  - examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1]
   - examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] # 10 mins
   - examples/test_enc_dec.py::test_llm_enc_dec_mmlu[flan-t5-small-float32-tp:1-pp:1-nb:1-enable_fp8] # 7 mins
   - examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8] # 13 mins
