Skip to content

Commit 456d5cd

Browse files
committed
nvbugs-5331031; nvbugs-5344203 - address intermittent issues with Mistral Small multimodal for BS=8
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 2b56957 commit 456d5cd

File tree

4 files changed: 24 additions and 5 deletions

4 files changed: 24 additions and 5 deletions

tensorrt_llm/runtime/multimodal_model_runner.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import torch.nn.functional as F
1616
from cuda import cudart
1717
from huggingface_hub import hf_hub_download
18-
from PIL import Image
18+
from PIL import Image, UnidentifiedImageError
1919
from safetensors import safe_open
2020
from torch import nn
2121
from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor,
@@ -2173,8 +2173,23 @@ def load_images(image_paths):
21732173
if image_path.startswith("http") or image_path.startswith(
21742174
"https"):
21752175
logger.info(f"downloading image from url {image_path}")
2176-
response = requests.get(image_path, timeout=5)
2177-
image = Image.open(BytesIO(response.content)).convert("RGB")
2176+
try:
2177+
response = requests.get(image_path, timeout=5)
2178+
response.raise_for_status()
2179+
if 'image' not in response.headers.get(
2180+
'Content-Type', ''):
2181+
raise Exception(
2182+
f"URL does not point to an image: {image_path}."
2183+
)
2184+
image = Image.open(BytesIO(
2185+
response.content)).convert("RGB")
2186+
except (UnidentifiedImageError, IOError):
2187+
raise Exception(
2188+
f"Cannot identify image file at URL: {image_path}.")
2189+
except Exception as e:
2190+
raise Exception(
2191+
f"Failed to download image from url {image_path}: {e}"
2192+
)
21782193
else:
21792194
image = Image.open(image_path).convert("RGB")
21802195
images.append(image)

tests/integration/defs/examples/test_multimodal.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
import os
1717

1818
import pytest
19+
import torch
1920
from defs.common import convert_weights, venv_check_call, venv_mpi_check_call
2021
from defs.conftest import get_device_memory, skip_post_blackwell, skip_pre_ada
2122
from defs.trt_test_alternative import check_call
@@ -75,6 +76,10 @@ def _test_llm_multimodal_general(llm_venv,
7576
cpp_e2e=False,
7677
num_beams=1):
7778

79+
# Empty the torch CUDA cache before each multimodal test to reduce risk of OOM errors.
80+
if torch.cuda.is_available():
81+
torch.cuda.empty_cache()
82+
7883
world_size = tp_size * pp_size
7984
print("Locate model checkpoints in test storage...")
8085
tllm_model_name, model_ckpt_path = multimodal_model_root

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,7 @@ l0_h100:
222222
- examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
223223
- examples/test_multimodal.py::test_llm_multimodal_general[Phi-3.5-vision-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
224224
- examples/test_multimodal.py::test_llm_multimodal_general[Phi-4-multimodal-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
225-
- examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1]
225+
- examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]
226226
- examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] # 10 mins
227227
- examples/test_enc_dec.py::test_llm_enc_dec_mmlu[flan-t5-small-float32-tp:1-pp:1-nb:1-enable_fp8] # 7 mins
228228
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-enable_fp8] # 13 mins

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -415,7 +415,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpu
415415
test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5333659)
416416
test_e2e.py::test_ptp_quickstart_advanced[Mixtral-8x7B-NVFP4-nvfp4-quantized/Mixtral-8x7B-Instruct-v0.1] SKIP (https://nvbugs/5333659)
417417
test_e2e.py::test_ptp_quickstart_advanced[Nemotron-Super-49B-v1-NVFP4-nvfp4-quantized/Llama-3_3-Nemotron-Super-49B-v1_nvfp4_hf] SKIP (https://nvbugs/5333659)
418-
examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5331031)
419418
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
420419
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
421420
accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True] SKIP (https://nvbugs/5336321)

0 commit comments

Comments
 (0)