43 changes: 34 additions & 9 deletions examples/llm-api/quickstart_multimodal.py
@@ -55,7 +55,26 @@
"Describe the scene in the image briefly.",
"",
]
}
},
"multiple_image": {
"media": [
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
"https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
],
"prompt": ["Describe the difference between the two images."],
},
"mixture_text_image": {
"media": [
[],
[
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
],
],
"prompt": [
"Who invented the internet?",
"Describe the scene in the image briefly.",
],
},
}
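
Usage note for the new example entries above: prompts and media are parallel lists, so prompts[i] is paired with media[i], and for mixture_text_image an empty inner list marks a text-only request. A minimal sketch of that pairing follows; the loop and variable names are illustrative and not code from this PR.

```python
# Illustrative pairing only; the real consumption happens inside
# default_multimodal_input_loader.
prompts = [
    "Who invented the internet?",
    "Describe the scene in the image briefly.",
]
media = [
    [],  # empty list -> text-only sample
    ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"],
]

for prompt, sample_media in zip(prompts, media):
    kind = "text-only" if not sample_media else f"{len(sample_media)} image(s)"
    print(f"[{kind}] {prompt}")
```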


@@ -66,7 +85,10 @@ def add_multimodal_args(parser):
help="Model type.")
parser.add_argument("--modality",
type=str,
choices=["image", "video", "audio", "image_audio"],
choices=[
"image", "video", "audio", "image_audio",
"multiple_image", "mixture_text_image"
Review comment (Collaborator): Sorry to comment again on a closed PR, but a quick question: do we actually need to create or define a new modality (other than image, video, etc.) here when there are multiple images or videos?

Could we instead update the default loader to accommodate the various combinations (pure text, multiple images with text, a single image with text, etc.)? (See the sketch after the add_multimodal_args diff below for one possible direction.)

],
default="image",
help="Media type.")
parser.add_argument("--media",
@@ -82,6 +104,10 @@ def add_multimodal_args(parser):
choices=["pt", "pil"],
default="pt",
help="The format of the image.")
parser.add_argument("--device",
type=str,
default="cpu",
help="The device to have the input on.")
return parser
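
Picking up the review comment above about whether a dedicated modality choice is needed: one alternative, sketched under the assumption that the loader could branch per sample, is to infer each sample's modality from its media entry. infer_sample_modality below is hypothetical and not part of this PR.

```python
from typing import List, Union


def infer_sample_modality(media: Union[str, List[str]]) -> str:
    """Hypothetical helper: derive a sample's modality from its media entry
    rather than from a dedicated CLI choice such as "mixture_text_image"."""
    if not media:  # empty string or empty list -> plain text
        return "text"
    if isinstance(media, str):
        media = [media]
    if any(m.endswith((".mp4", ".avi", ".webm")) for m in media):
        return "video"
    return "image" if len(media) == 1 else "multiple_image"


# A mixed batch then needs no special top-level modality flag:
print(infer_sample_modality([]))                  # text
print(infer_sample_modality(["inpaint.png"]))     # image
print(infer_sample_modality(["a.jpg", "b.jpg"]))  # multiple_image
```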


@@ -114,11 +140,6 @@ def parse_arguments():

def main():
args = parse_arguments()
# set prompts and media to example prompts and images if they are not provided
if args.prompt is None:
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
if args.media is None:
args.media = example_medias_and_prompts[args.modality]["media"]

lora_config = None
if args.load_lora:
@@ -138,7 +159,11 @@ def main():
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"

device = "cpu"
# set prompts and media to example prompts and images if they are not provided
if args.prompt is None:
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
if args.media is None:
args.media = example_medias_and_prompts[args.modality]["media"]
inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
model_dir=llm._hf_model_dir,
model_type=model_type,
@@ -147,7 +172,7 @@ def main():
media=args.media,
image_data_format=image_format,
num_frames=args.num_frames,
device=device)
device=args.device)

lora_request = None
if args.load_lora:
3 changes: 0 additions & 3 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -187,9 +187,6 @@ def forward(
multimodal_param.multimodal_data["image"]["pixel_values"]
for multimodal_param in multimodal_params
]
assert pixel_values == [] or len(
pixel_values
) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"

mm_embeds = []
mm_token_mask = None
6 changes: 1 addition & 5 deletions tensorrt_llm/_torch/models/modeling_hyperclovax.py
@@ -597,7 +597,7 @@ def _post_process(self,
input_ids: torch.Tensor,
preprocessed_image: dict[str, any] = None):
if not preprocessed_image:
return input_ids
return input_ids[0]

vision_query_lengths = preprocessed_image.get("vision_query_lengths",
None)
@@ -659,7 +659,6 @@ def _preprocess(self, text_prompt: dict[str, any], images: List[Any],
mm_processor_kwargs: Dict[str, Any]):

preprocessed_image = None
is_video_list = [False] * len(images)
if images is not None:
is_video_list = [False] * len(images)
preprocessed_image = self.processor(
@@ -1026,9 +1025,6 @@ def forward(
multimodal_params = kwargs.get("multimodal_params", [])
mm_embeds = []
if len(multimodal_params) > 0:
assert len(multimodal_params) == num_context_requests == len(
multimodal_params
), f"Number of multimodal tensors ({len(multimodal_params)}) should be equal to number of context requests ({num_context_requests}) in the batch."
if not DISAGG:
mm_embeds = self.mm_encoder.forward(multimodal_params)
else:
9 changes: 5 additions & 4 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -1060,13 +1060,14 @@ def forward(
**kwargs,
) -> torch.Tensor:
multimodal_params = kwargs.get("multimodal_params", [])
if multimodal_params:
mm_embed = [
mm_embeds = []
if len(multimodal_params) > 0:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
input_ids, inputs_embeds = fuse_input_embeds(
self.model.embed_tokens, input_ids, mm_embed)
input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
input_ids, mm_embeds)
return super().forward(attn_metadata,
input_ids,
position_ids,
19 changes: 10 additions & 9 deletions tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -201,11 +201,13 @@ def __call__(
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
text_prompt, mm_data = inputs.get("prompt"), inputs.get(
"multi_modal_data", {})
assert 'image' in mm_data

input_ids = self.tokenizer(
text_prompt, return_tensors="pt").input_ids[0].to(self.device)

if not mm_data:
return input_ids.to(torch.int32).tolist(), {}

mm_tensor = self._preprocess(mm_data['image'])
mm_features = torch.stack(
[self._process(tensor) for tensor in mm_tensor])
@@ -274,16 +276,15 @@ def forward(
logger.debug(f"{num_context_requests=}, {num_generation_requests=}")

multimodal_params = kwargs.get("multimodal_params", [])
mm_embed = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
assert mm_embed == [] or len(
mm_embed
) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"
mm_embeds = []
if len(multimodal_params) > 0:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]

input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embed)
self.llm.model.embed_tokens, input_ids, mm_embeds)
logits = self.llm.forward(attn_metadata, input_ids, position_ids,
inputs_embeds, return_context_logits)
return logits
10 changes: 3 additions & 7 deletions tensorrt_llm/_torch/models/modeling_mistral.py
@@ -354,13 +354,9 @@ def forward(
logger.debug(f"{num_context_requests=}, {num_generation_requests=}")

multimodal_params = kwargs.get("multimodal_params", [])
image_features = []
mm_embeds = []
multimodal_params_len = len(multimodal_params)
if multimodal_params_len > 0:
if multimodal_params_len != num_context_requests:
raise RuntimeError(
f"Number of multimodal tensors ({multimodal_params_len}) should be equal to number of "
f"context requests ({num_context_requests}) in the batch.")
pixel_values = [
x.multimodal_data["image"]["pixel_values"]
for x in multimodal_params
@@ -377,15 +373,15 @@
f"({multimodal_params_len}).")
batched_pixel_values, batched_image_sizes = self._batch_pixel_values(
pixel_values=pixel_values, image_sizes=image_sizes)
image_features = [
mm_embeds = [
self._get_image_features(pixel_values=batched_pixel_values,
image_sizes=batched_image_sizes)
]

input_ids, inputs_embeds = fuse_input_embeds(
embedding_layer=self.llm.model.embed_tokens,
input_ids=input_ids,
mm_embeds=image_features,
mm_embeds=mm_embeds,
mm_token_ids=self._image_token_ids,
)

12 changes: 7 additions & 5 deletions tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -215,14 +215,16 @@ def forward(
)

multimodal_params = kwargs.get("multimodal_params", [])
mm_embedding = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
mm_embeds = []
if len(multimodal_params) > 0:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
input_ids, input_embeds = fuse_input_embeds(
self.llm.model.embed_tokens,
input_ids,
mm_embedding,
mm_embeds,
mm_token_ids=self.MM_TOKEN_IDS,
)

19 changes: 10 additions & 9 deletions tensorrt_llm/_torch/models/modeling_vila.py
@@ -1102,6 +1102,9 @@ def __call__(
input_ids = self.tokenizer(
text_prompt, return_tensors="pt").input_ids[0].to(self.device)

if not mm_data:
return input_ids.to(torch.int32).tolist(), {}

mm_tensor, block_sizes = self._preprocess(
mm_data, mm_processor_kwargs, use_fast=True
) # use_fast uses Pytorch GPU preprocessing, otherwise uses PIL CPU preprocessing
@@ -1164,17 +1167,15 @@ def forward(

num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
multimodal_params = kwargs.get("multimodal_params", [])
mm_embed = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]

assert mm_embed == [] or len(
mm_embed
) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"
mm_embeds = []
if len(multimodal_params) > 0:
mm_embeds = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]

input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embed)
self.llm.model.embed_tokens, input_ids, mm_embeds)
logits = self.llm.forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
35 changes: 22 additions & 13 deletions tensorrt_llm/inputs/utils.py
@@ -487,9 +487,9 @@ def convert_to_conversation_message(prompt: str, media: Union[str,
modality: str) -> ConversationMessage:
if isinstance(media, str):
media = [media]
if modality == "image":
if modality in ["image", "multiple_image"]:
mm_data = [
MultimodalData(modality=modality,
MultimodalData(modality="image",
data=load_image(i,
format=image_data_format,
device=device)) for i in media
@@ -530,6 +530,15 @@
if _modal is None:
raise ValueError(f"Unknown matching modality: {modality}")
mm_data.append(MultimodalData(modality=_modal, data=data))
elif modality == "mixture_text_image":
mm_data = []
for m in media:
if m:
mm_data.append(
MultimodalData(modality="image",
data=load_image(m,
format=image_data_format,
device=device)))
else:
raise ValueError(f"Unknown modality: {modality}")
return ConversationMessage(role="user", content=prompt, media=mm_data)
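
To make the mixture_text_image branch above concrete: each per-prompt media list is converted element-wise, so a text-only sample yields a ConversationMessage with an empty media list. The stand-in types and load_image below are simplified placeholders, not the real tensorrt_llm helpers.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class MultimodalData:  # simplified stand-in for the real type
    modality: str
    data: Any


@dataclass
class ConversationMessage:  # simplified stand-in for the real type
    role: str
    content: str
    media: List[MultimodalData] = field(default_factory=list)


def load_image(path: str) -> Any:  # placeholder for the real image loader
    return f"<pixels of {path}>"


def to_message(prompt: str, media: List[str]) -> ConversationMessage:
    # Mirrors the branch above: skip falsy entries, wrap the rest as images.
    mm_data = [MultimodalData("image", load_image(m)) for m in media if m]
    return ConversationMessage(role="user", content=prompt, media=mm_data)


print(to_message("Who invented the internet?", []).media)            # []
print(to_message("Describe the scene briefly.", ["cat.png"]).media)  # one image entry
```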
@@ -561,16 +570,16 @@ def convert_to_conversation_message(prompt: str, media: Union[str,
if mm_placeholder_counts:
conv["content"] = add_multimodal_placeholders(
model_type, conv["content"], mm_placeholder_counts)
prompt = apply_chat_template(
model_type=model_type,
tokenizer=tokenizer,
processor=processor,
conversation=[conv],
add_generation_prompt=True,
mm_placeholder_counts=mm_placeholder_counts)
inputs.append({
"prompt": prompt,
"multi_modal_data": mm_data_tracker.retrieve_all_sync()
})
prompt = apply_chat_template(
model_type=model_type,
tokenizer=tokenizer,
processor=processor,
conversation=[conv],
add_generation_prompt=True,
mm_placeholder_counts=mm_placeholder_counts)
input = {"prompt": prompt}
if mm_placeholder_counts:
input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync()
inputs.append(input)

return inputs
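
One consequence of the change above is that text-only conversations now yield an input dict without a "multi_modal_data" key, so downstream code should treat that key as optional. The dict shapes below are illustrative only; the real prompt strings come from apply_chat_template and the real payload from the multimodal data tracker.

```python
# Illustrative shapes only, assuming one text-only and one image sample.
inputs = [
    {"prompt": "<templated text-only prompt>"},  # no "multi_modal_data" key
    {
        "prompt": "<templated image prompt>",
        "multi_modal_data": {"image": ["<loaded image>"]},
    },
]

for item in inputs:
    mm = item.get("multi_modal_data")  # use .get(): the key may be absent
    print("has media:", mm is not None, "|", item["prompt"])
```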
15 changes: 14 additions & 1 deletion tests/integration/defs/test_e2e.py
@@ -1939,7 +1939,7 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):


@pytest.mark.parametrize("use_cuda_graph", [False, True])
@pytest.mark.parametrize("modality", ["image", "video"])
@pytest.mark.parametrize("modality", ["image", "video", "mixture_text_image"])
@pytest.mark.parametrize("model_name,model_path", [
("NVILA-8B-FP16", "vila/NVILA-8B"),
("NVILA-15B-FP16", "NVILA-15B"),
@@ -1987,6 +1987,16 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
str(test_data_root / "world.mp4"),
],
},
"mixture_text_image": {
"prompt": [
"Who invented the internet?",
"Describe the scene in the image briefly.",
],
"media": [
[],
[str(test_data_root / "inpaint.png")],
],
}
}

expected_keywords = {
@@ -2042,6 +2052,9 @@
["scenic", "rock", "landscape", "snow", "altitude"],
["highway", "traffic", "directions", "lanes", "Jurong"],
],
"mixture_text_image":
[["invention", "person", "scientists", "Lick", "engineers"],
["landscape", "dome", "yosemite", "altitude", "scattered"]]
},
"gemma-3-27b-it": {
"image": [
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/examples_test_list.txt
@@ -536,6 +536,7 @@ test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B
test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/llm_sanity_test.txt
@@ -102,6 +102,7 @@ test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert
test_e2e.py::test_ptp_quickstart_multimodal[llava-v1.6-mistral-7b-llava-v1.6-mistral-7b-hf-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-video-False]
test_e2e.py::test_ptp_quickstart_multimodal[qwen2-vl-7b-instruct-Qwen2-VL-7B-Instruct-image-False]
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -196,6 +196,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
- condition:
ranges:
system_gpu_count: