From 32720996d1956ec3e3644e07d5134dd839c4db79 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:13:28 -0700 Subject: [PATCH 01/13] Enable mm_emb as input for llama4 --- examples/pytorch/quickstart_multimodal.py | 115 +++++++++++++++++++ tensorrt_llm/_torch/models/modeling_llama.py | 99 +++++++++++++++- tensorrt_llm/inputs/utils.py | 55 ++++++--- tensorrt_llm/llmapi/llm.py | 8 +- 4 files changed, 259 insertions(+), 18 deletions(-) create mode 100644 examples/pytorch/quickstart_multimodal.py diff --git a/examples/pytorch/quickstart_multimodal.py b/examples/pytorch/quickstart_multimodal.py new file mode 100644 index 00000000000..8a4635b8611 --- /dev/null +++ b/examples/pytorch/quickstart_multimodal.py @@ -0,0 +1,115 @@ +import argparse +import json +import os + +from quickstart_advanced import add_llm_args, setup_llm + +from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS, + default_multimodal_input_loader) + +example_images = [ + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png", + "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg", +] +example_image_prompts = [ + "Describe the natural environment in the image.", + "Describe the object and the weather condition in the image.", + "Describe the traffic condition on the road in the image.", +] +example_videos = [ + "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4", + "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4", +] +example_video_prompts = [ + "Tell me what you see in the video briefly.", + "Describe the scene in the video briefly.", +] + + +def add_multimodal_args(parser): + parser.add_argument("--model_type", + type=str, + choices=ALL_SUPPORTED_MULTIMODAL_MODELS, + help="Model type.") + parser.add_argument("--modality", + type=str, + choices=["image", "video"], + default="image", + help="Media type.") + parser.add_argument("--media", + type=str, + nargs="+", + help="A single or a list of media filepaths / urls.") + parser.add_argument("--num_frames", + type=int, + default=8, + help="The number of video frames to be sampled.") + parser.add_argument("--image_format", + type=str, + choices=["pt", "pil"], + default="pt", + help="The format of the image.") + return parser + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Multimodal models with the PyTorch workflow.") + parser = add_llm_args(parser) + parser = add_multimodal_args(parser) + args = parser.parse_args() + + args.disable_kv_cache_reuse = True # kv cache reuse does not work for multimodal, force overwrite + if args.kv_cache_fraction is None: + args.kv_cache_fraction = 0.6 # lower the default kv cache fraction for multimodal + + return args + + +def main(): + args = parse_arguments() + # set prompts and media to example prompts and images if they are not provided + if args.prompt is None: + args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts + if args.media is None: + args.media = example_images if args.modality == "image" else example_videos + + llm, sampling_params = setup_llm(args) + + image_format = args.image_format + if args.model_type is not None: + model_type = args.model_type + else: + model_type = json.load( + 
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type'] + assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}" + + device = "cuda" + import torch + mm_embeds_1 = torch.load("maverick_mm_embed_seashore_v2.pt") + mm_embeds_2 = torch.load("maverick_mm_embed_inpaint_v2.pt") + mm_embeds_3 = torch.load("maverick_mm_embed_highway_61_v2.pt") + mm_embeds = [mm_embeds_1, mm_embeds_2, mm_embeds_3] + + inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer, + model_dir=llm._hf_model_dir, + model_type=model_type, + modality=args.modality, + prompts=args.prompt, + media=args.media, + image_data_format=image_format, + num_frames=args.num_frames, + mm_embeddings=mm_embeds, + device=device) + + outputs = llm.generate(inputs, sampling_params) + + for i, output in enumerate(outputs): + prompt = args.prompt[i] + generated_text = output.outputs[0].text + print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 33dddfc784c..39d03a6c958 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -1,5 +1,6 @@ import copy -from typing import Dict, List, Optional, Tuple, Union +import os +from typing import Dict, List, Optional, Tuple, Union, Any import torch from PIL.Image import Image @@ -807,6 +808,11 @@ def __init__(self, self.tokenizer = tokenizer self.vocab_size = model_config.text_config.vocab_size self.image_token_index = model_config.image_token_index + # TODO: use tokenizer to get these special tokens/values + self.fake_image_token = "<|image|>" # this must be the same as placeholder prompt + self.image_token = "<|patch|>" + self.image_token_start_index = 200080 + self.image_token_end_index = 200081 self.encoder = nn.ModuleDict({ "vision_model": @@ -816,6 +822,97 @@ def __init__(self, }).cuda() load_sharded_checkpoint(self.encoder, model_path, strict=False) + def postprocess(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[Dict[str, Any]]]) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + """ + Post-process multimodal embeddings for Llama4 model. 
+ + Args: + inputs: Text prompt containing image placeholders + multimodal_embedding: Dictionary containing image embedding data with special token information + + Returns: + Tuple of (token_ids, extra_processed_inputs) where: + - token_ids: List of processed token IDs + - extra_processed_inputs: Optional dictionary containing multimodal embeddings + """ + text_prompt = inputs.get("prompt") + if not text_prompt: + raise ValueError("Text prompt is required but not provided") + + # Validate multimodal embedding structure + if not isinstance(multimodal_embedding, dict): + raise ValueError("multimodal_embedding must be a dictionary") + + if 'image' not in multimodal_embedding: + raise ValueError("Only image modality is supported for now") + + mm_embedding_info = multimodal_embedding['image'] + if not mm_embedding_info or not isinstance(mm_embedding_info[0], dict): + raise ValueError("Llama4 image embedding must contain special token information") + + # Extract embedding components + try: + mm_embeddings = [mm_embedding['mm_embeddings'] for mm_embedding in mm_embedding_info] + mm_embedding_special_tokens = [mm_embedding['image_special_tokens'] for mm_embedding in mm_embedding_info] + mm_embedding_special_offsets = [mm_embedding['image_special_token_offsets'] for mm_embedding in mm_embedding_info] + except KeyError as e: + raise ValueError(f"Missing required key in multimodal embedding: {e}") + + # Validate embedding dimensions + model_hidden_size = self.model_config.text_config.hidden_size + for i, embedding in enumerate(mm_embeddings): + if embedding.shape[-1] != model_hidden_size: + raise ValueError( + f"Multimodal embedding {i} hidden size {embedding.shape[-1]} " + f"must match model hidden size {model_hidden_size}" + ) + + # Count image placeholders (number of images) in the prompt + total_placeholders = text_prompt.count(self.fake_image_token) + if total_placeholders == 0: + raise ValueError("No image placeholders found in the prompt, but multimodal embedding was provided") + + if total_placeholders != len(mm_embeddings): + raise ValueError( + f"Number of image placeholders ({total_placeholders}) " + f"does not match number of embeddings ({len(mm_embeddings)})" + ) + + # Process prompt with image embeddings + prompt_splits = text_prompt.split(self.fake_image_token) + new_prompt_parts = [] + + for local_image_index, split_part in enumerate(prompt_splits): + new_prompt_parts.append(split_part) + + if local_image_index < total_placeholders: + # Calculate total tokens for this image + num_tokens = len(mm_embeddings[local_image_index]) + len(mm_embedding_special_tokens[local_image_index]) + + # Create image token sequence + image_tokens = [self.image_token] * num_tokens + + # Replace special tokens with actual decoded tokens + for offset, token_id in zip(mm_embedding_special_offsets[local_image_index], + mm_embedding_special_tokens[local_image_index]): + if offset < len(image_tokens): + image_tokens[offset] = self.tokenizer.decode([token_id]) + + # Join tokens without spaces + image_str = "".join(image_tokens) + new_prompt_parts.append(image_str) + + # Combine all parts and tokenize + processed_text = "".join(new_prompt_parts) + text_inputs = self.tokenizer(processed_text, return_tensors="pt", add_special_tokens=False) + token_ids = text_inputs.input_ids.squeeze() + + # Replace image token indices with out-of-vocabulary tokens + token_ids[token_ids == self.image_token_index] = self.vocab_size + 1 + # Concatenate all multimodal embeddings + mm_embeds = torch.cat(mm_embeddings, dim=0) + return 
token_ids.tolist(), {"mm_embedding": mm_embeds} + @torch.inference_mode() def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index a4bf8570d0a..621489466e0 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -322,7 +322,7 @@ class ConversationMessage(TypedDict): """Type definition for conversation message structure.""" role: str content: List[dict[str, Any]] - media: List[MultimodalData] + media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]] # @classmethod # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage": @@ -480,20 +480,29 @@ def default_multimodal_input_loader( media: Union[List[str], List[List[str]]], image_data_format: str = "pt", num_frames: int = 8, + mm_embeddings: Optional[Union[List[torch.Tensor], List[Dict[str, Any]]]] = None, device: str = "cpu") -> List[dict[str, Union[str, torch.Tensor]]]: def convert_to_conversation_message(prompt: str, media: Union[str, List[str]], - modality: str) -> ConversationMessage: + modality: str, + mm_embedding: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> ConversationMessage: if isinstance(media, str): media = [media] if modality == "image": - mm_data = [ - MultimodalData(modality=modality, - data=load_image(i, - format=image_data_format, - device=device)) for i in media - ] + if mm_embedding is not None: + # each mm_embedding corresponds to each image placeholder + if not isinstance(mm_embedding, list): + mm_embedding = [mm_embedding] + + mm_data = [{'modality': modality, 'mm_embedding_info': mm} for mm in mm_embedding] + else: + mm_data = [ + MultimodalData(modality=modality, + data=load_image(i, + format=image_data_format, + device=device)) for i in media + ] elif modality == "video": mm_data = [ MultimodalData(modality=modality, @@ -551,11 +560,20 @@ def convert_to_conversation_message(prompt: str, media: Union[str, trust_remote_code=True) inputs = [] - for prompt, media in zip(prompts, media): - conv = convert_to_conversation_message(prompt, media, modality) + for prompt_idx, (prompt, media) in enumerate(zip(prompts, media)): + if mm_embeddings is not None: + mm_embedding = mm_embeddings[prompt_idx] + conv = convert_to_conversation_message(prompt, media, modality, mm_embedding) + else: + conv = convert_to_conversation_message(prompt, media, modality) mm_data_tracker = MultimodalDataTracker(model_type) for mdata in conv["media"]: - mm_data_tracker.add_data(mdata["modality"], mdata["data"]) + # Check if mdata is a MultimodalData + if isinstance(mdata, dict) and "modality" in mdata and "data" in mdata: + mm_data_tracker.add_data(mdata["modality"], mdata["data"]) + else: + # Add embeddings to the tracker for placeholder handling + mm_data_tracker.add_data(mdata["modality"], mdata["mm_embedding_info"]) mm_placeholder_counts = mm_data_tracker.placeholder_counts() prompt = conv["content"] if mm_placeholder_counts: @@ -568,9 +586,16 @@ def convert_to_conversation_message(prompt: str, media: Union[str, conversation=[conv], add_generation_prompt=True, mm_placeholder_counts=mm_placeholder_counts) - inputs.append({ - "prompt": prompt, - "multi_modal_data": mm_data_tracker.retrieve_all_sync() - }) + + if mm_embeddings is not None: + inputs.append({ + "prompt": prompt, + "multi_modal_embeddings": mm_data_tracker.retrieve_all_sync() + }) + else: + inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data_tracker.retrieve_all_sync() + }) return inputs diff --git 
a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index dcf3ca92902..b889c0ac870 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -342,8 +342,9 @@ def generate_async( inputs = prompt_inputs(inputs) if not inputs.get("prompt") and inputs.get( - "prompt_token_ids") and inputs.get( - "multi_modal_data") and not isinstance( + "prompt_token_ids") and (inputs.get( + "multi_modal_data") or inputs.get( + "multi_modal_embeddings")) and not isinstance( self.input_processor, DefaultInputProcessor): # VLMs need to process/tokenize the prompt in their own way prompt = self.tokenizer.decode(inputs['prompt_token_ids']) @@ -377,6 +378,9 @@ def generate_async( with nvtx_range_debug("input_processor_with_hash"): prompt_token_ids, extra_processed_inputs = input_processor_with_hash( inputs, sampling_params) + elif 'multi_modal_embeddings' in inputs: + mm_embedding_info = inputs['multi_modal_embeddings'] + prompt_token_ids, extra_processed_inputs = self.input_processor.postprocess(inputs, mm_embedding_info) else: with nvtx_range_debug("input_processor"): prompt_token_ids, extra_processed_inputs = self.input_processor( From 44674365414e09ae5a1bf0c9d3d999ad79db22ab Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:53:57 -0700 Subject: [PATCH 02/13] Minor update --- examples/pytorch/quickstart_multimodal.py | 6 +++--- tensorrt_llm/_torch/models/modeling_llama.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/quickstart_multimodal.py b/examples/pytorch/quickstart_multimodal.py index 8a4635b8611..75d0eba184b 100644 --- a/examples/pytorch/quickstart_multimodal.py +++ b/examples/pytorch/quickstart_multimodal.py @@ -87,9 +87,9 @@ def main(): device = "cuda" import torch - mm_embeds_1 = torch.load("maverick_mm_embed_seashore_v2.pt") - mm_embeds_2 = torch.load("maverick_mm_embed_inpaint_v2.pt") - mm_embeds_3 = torch.load("maverick_mm_embed_highway_61_v2.pt") + mm_embeds_1 = torch.load("scout_mm_embed_seashore_v2.pt") + mm_embeds_2 = torch.load("scout_mm_embed_inpaint_v2.pt") + mm_embeds_3 = torch.load("scout_mm_embed_highway_61_v2.pt") mm_embeds = [mm_embeds_1, mm_embeds_2, mm_embeds_3] inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer, diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 39d03a6c958..885cdcc481e 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -828,8 +828,8 @@ def postprocess(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[D Args: inputs: Text prompt containing image placeholders - multimodal_embedding: Dictionary containing image embedding data with special token information - + multimodal_embedding: Dictionary containing image embedding data with special token information. + Consider adding metadata fields (e.g., model_type, model_name, version) for validation. 
Returns: Tuple of (token_ids, extra_processed_inputs) where: - token_ids: List of processed token IDs @@ -904,7 +904,8 @@ def postprocess(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[D # Combine all parts and tokenize processed_text = "".join(new_prompt_parts) - text_inputs = self.tokenizer(processed_text, return_tensors="pt", add_special_tokens=False) + # TODO: pass sampling_params.add_special_tokens to tokenizer + text_inputs = self.tokenizer(processed_text, return_tensors="pt", add_special_tokens=True) token_ids = text_inputs.input_ids.squeeze() # Replace image token indices with out-of-vocabulary tokens From f7e9188aeb029ee1372eec1fbdc65762f188a8cd Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Mon, 21 Jul 2025 16:20:00 -0700 Subject: [PATCH 03/13] Update to use shared tensor for main-worker comm. --- tensorrt_llm/_torch/models/modeling_llama.py | 25 ++-- .../_torch/models/modeling_llava_next.py | 30 ++++- tensorrt_llm/executor/worker.py | 2 + tensorrt_llm/inputs/multimodal.py | 127 ++++++++++++++++++ tensorrt_llm/llmapi/llm.py | 4 +- 5 files changed, 177 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 885cdcc481e..3c889942e3d 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -822,24 +822,27 @@ def __init__(self, }).cuda() load_sharded_checkpoint(self.encoder, model_path, strict=False) - def postprocess(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[Dict[str, Any]]]) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[Dict[str, Any]]], sampling_params: SamplingParams) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: """ - Post-process multimodal embeddings for Llama4 model. + Attach pre-processed multimodal embeddings into text token stream for Llama4 model. + + This method skips vision processing and works with externally provided embeddings. + It replaces/expands image placeholders in the text with appropriate tokens and prepares + the embeddings for model forward pass. Args: inputs: Text prompt containing image placeholders - multimodal_embedding: Dictionary containing image embedding data with special token information. + multimodal_embedding: Dictionary containing pre-processed image embedding data with special token information. Consider adding metadata fields (e.g., model_type, model_name, version) for validation. 
Returns: Tuple of (token_ids, extra_processed_inputs) where: - - token_ids: List of processed token IDs + - token_ids: List of processed token IDs with image placeholders - extra_processed_inputs: Optional dictionary containing multimodal embeddings """ text_prompt = inputs.get("prompt") if not text_prompt: raise ValueError("Text prompt is required but not provided") - # Validate multimodal embedding structure if not isinstance(multimodal_embedding, dict): raise ValueError("multimodal_embedding must be a dictionary") @@ -904,15 +907,19 @@ def postprocess(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[D # Combine all parts and tokenize processed_text = "".join(new_prompt_parts) - # TODO: pass sampling_params.add_special_tokens to tokenizer - text_inputs = self.tokenizer(processed_text, return_tensors="pt", add_special_tokens=True) + kwargs = {} + if sampling_params.truncate_prompt_tokens is not None: + kwargs = dict(truncation=True, + max_length=sampling_params.truncate_prompt_tokens) + text_inputs = self.tokenizer(processed_text, add_special_tokens=sampling_params.add_special_tokens, **kwargs) token_ids = text_inputs.input_ids.squeeze() # Replace image token indices with out-of-vocabulary tokens token_ids[token_ids == self.image_token_index] = self.vocab_size + 1 # Concatenate all multimodal embeddings - mm_embeds = torch.cat(mm_embeddings, dim=0) - return token_ids.tolist(), {"mm_embedding": mm_embeds} + multimodal_data = {} + multimodal_data["multimodal_embedding"] = torch.cat(mm_embeddings, dim=0) + return token_ids.tolist(), {"multimodal_data": multimodal_data} @torch.inference_mode() def __call__( diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 8af484ce1ab..74a26dabe95 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -1,6 +1,6 @@ import copy import os -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict import torch import torch.nn as nn @@ -195,6 +195,34 @@ def _postprocess(self, input_ids, mm_features): mm_features = mm_features.view(-1, mm_features.shape[-1]) return fused_input_ids, mm_features + def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[torch.Tensor]], sampling_params: SamplingParams) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + """ + Attach pre-processed multimodal embeddings into text token stream for Llama4 model. + + This method skips vision processing and works with externally provided embeddings. + It replaces/expands image placeholders in the text with appropriate tokens and prepares + the embeddings for model forward pass. 
+ + Args: + inputs: Text prompt containing image placeholders + multimodal_embedding: Dictionary containing pre-processed image embedding data + Returns: + Tuple of (token_ids, extra_processed_inputs) where: + - token_ids: List of processed token IDs with image placeholders + - extra_processed_inputs: Optional dictionary containing multimodal embeddings + """ + text_prompt = inputs.get("prompt") + assert 'image' in multimodal_embedding + + input_ids = self.tokenizer( + text_prompt, return_tensors="pt").input_ids[0].to(self.device) + fused_input_ids, mm_features = self._postprocess(input_ids, multimodal_embedding['image']) + multimodal_data = {} + multimodal_data["multimodal_embedding"] = mm_features + return fused_input_ids.to(torch.int32).tolist(), { + "multimodal_data": multimodal_data + } + @torch.inference_mode() def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 6ebd7adc03d..b4b81fa8653 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -500,6 +500,8 @@ def _deduce_max_tokens(request: GenerationRequest, if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: + # Convert back to tensor, as opposite to `to_handle` in `llm.generate_async` + request.multimodal_params.to_tensor("multimodal_data", key="multimodal_embedding") executor_request.py_multimodal_data = request.multimodal_params.multimodal_data if self._is_pytorch_backend and request.sampling_params.logits_processor: diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index 19d55ae7744..f230e9954b3 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -234,6 +234,133 @@ def _to_device( f"MultimodalParams: Unsupported element '{element}' to move to device. " f"Supported elements: 'multimodal_data', 'multimodal_input'") + def to_handle(self, element: str, key: Optional[str] = "multimodal_embedding") -> None: + """Convert multimodal data to tensor handle. + + Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries) + for efficient IPC. This function is a in-place operation. + + Args: + element: Element to convert ("multimodal_data" or "multimodal_input") + key: Specific key to convert. If None, converts all tensor values in multimodal_data. + Defaults to "multimodal_embedding". 
+ + Example: + # Convert all tensors in multimodal_data to handles + params.to_handle("multimodal_data", key=None) + + # Convert only multimodal_embedding section tensors to handles + params.to_handle("multimodal_data", key="multimodal_embedding") + """ + # Lazy import to avoid circular dependency + from tensorrt_llm._torch.shared_tensor import SharedTensorContainer + + def _to_tensor_handle(data): + for k, v in data.items(): + if isinstance(v, torch.Tensor): + # Convert tensor to handle + handle = SharedTensorContainer.from_tensor(v).dump_to_dict() + data[k] = handle + elif isinstance(v, dict): + _to_tensor_handle(v) + elif isinstance(v, list): + for i, item in enumerate(v): + if isinstance(item, torch.Tensor): + handle = SharedTensorContainer.from_tensor(item).dump_to_dict() + v[i] = handle + + if element == "multimodal_data": + if self.multimodal_data is None: + return + if key is None: + _to_tensor_handle(self.multimodal_data) + else: + if key not in self.multimodal_data: + raise ValueError(f"Key '{key}' not found in multimodal_data") + + value = self.multimodal_data[key] + if isinstance(value, torch.Tensor): + handle = SharedTensorContainer.from_tensor(value).dump_to_dict() + self.multimodal_data[key] = handle + elif isinstance(value, dict): + _to_tensor_handle(value) + else: + raise ValueError(f"Unsupported value type for multimodal_data: {type(value)}") + elif element == "multimodal_input": + # No-op for multimodal_input + return + else: + raise ValueError(f"Unsupported element '{element}' to convert to handle.") + + def to_tensor(self, element: str, key: Optional[str] = "multimodal_embedding") -> None: + """Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle. + + Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects + for local computation. This function performs in-place modifications to the multimodal_data. + + Args: + element: Element to convert ("multimodal_data" or "multimodal_input") + key: Specific key to convert. If None, converts all tensor handles in multimodal_data. + Defaults to "multimodal_embedding". 
+ + Example: + # Convert all handles back to tensors + params.to_tensor("multimodal_data", key=None) + + # Convert only multimodal_embedding section handles back to tensors + params.to_tensor("multimodal_data", key="multimodal_embedding") + """ + # Lazy import to avoid circular dependency + from tensorrt_llm._torch.shared_tensor import SharedTensorContainer + + def _to_tensor(data): + for k, v in data.items(): + if isinstance(v, dict) and 'method_key' in v: + # This is a tensor handle (dict with method_key) + try: + tensor = SharedTensorContainer.from_dict(v).get_local_view() + data[k] = tensor + except Exception as e: + raise ValueError(f"Failed to convert handle to tensor for key '{k}': {e}") + elif isinstance(v, dict): + _to_tensor(v) + elif isinstance(v, list): + for i, item in enumerate(v): + if isinstance(item, dict) and 'method_key' in item: + try: + tensor = SharedTensorContainer.from_dict(item).get_local_view() + v[i] = tensor + except Exception as e: + raise ValueError(f"Failed to convert handle to tensor in list at index {i}: {e}") + + if element == "multimodal_data": + if self.multimodal_data is None: + return + + if key is None: + _to_tensor(self.multimodal_data) + else: + if key not in self.multimodal_data: + raise ValueError(f"Key '{key}' not found in multimodal_data") + + value = self.multimodal_data[key] + if isinstance(value, dict) and 'method_key' in value: # This is a tensor handle + try: + tensor = SharedTensorContainer.from_dict(value).get_local_view() + self.multimodal_data[key] = tensor + except Exception as e: + raise ValueError(f"Failed to convert handle to tensor for key '{key}': {e}") + elif isinstance(value, dict): + _to_tensor(value) + else: + raise ValueError(f"Unsupported value type for multimodal_data: {type(value)}") + + elif element == "multimodal_input": + # No-op for multimodal_input + return + else: + raise ValueError(f"Unsupported element '{element}' to convert to tensor.") + def strip_for_context(self) -> None: """Strip multimodal data for context processing. 
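
A minimal CPU-only, in-process sketch of the handle round trip introduced above: the llm.py hunk that follows converts the embedding tensor to a shared-tensor handle before dispatch, and the worker.py hunk earlier in this patch converts it back. Variable names and the tensor shape are illustrative.

    import torch

    from tensorrt_llm.inputs.multimodal import MultimodalParams

    # Producer side (llm.generate_async below): swap the tensor for a serializable
    # SharedTensorContainer handle so the embedding is not copied over IPC.
    params = MultimodalParams()
    params.multimodal_data = {"multimodal_embedding": torch.randn(3, 4, 5)}
    params.to_handle("multimodal_data", key="multimodal_embedding")  # tensor -> handle dict

    # Consumer side (executor worker): rebuild a local tensor view from the handle
    # before attaching the data to the executor request.
    params.to_tensor("multimodal_data", key="multimodal_embedding")  # handle dict -> tensor
    assert isinstance(params.multimodal_data["multimodal_embedding"], torch.Tensor)
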
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index b889c0ac870..f3841361e97 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -380,7 +380,7 @@ def generate_async( inputs, sampling_params) elif 'multi_modal_embeddings' in inputs: mm_embedding_info = inputs['multi_modal_embeddings'] - prompt_token_ids, extra_processed_inputs = self.input_processor.postprocess(inputs, mm_embedding_info) + prompt_token_ids, extra_processed_inputs = self.input_processor.attch_multimodal_embeddings(inputs, mm_embedding_info, sampling_params) else: with nvtx_range_debug("input_processor"): prompt_token_ids, extra_processed_inputs = self.input_processor( @@ -394,6 +394,8 @@ def generate_async( 'multimodal_input'), multimodal_data=extra_processed_inputs.get( 'multimodal_data')) + # Convert to shared tensor handle to reduce IPC overhead + multimodal_params.to_handle("multimodal_data", key="multimodal_embedding") # Only pass it if it has content if not multimodal_params.has_content(): multimodal_params = None From d5d81f33ab3f5c6a0357e2da92b48ad89aff8920 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Mon, 21 Jul 2025 16:30:47 -0700 Subject: [PATCH 04/13] Format --- tensorrt_llm/_torch/models/modeling_llama.py | 59 +++++++++++++------ .../_torch/models/modeling_llava_next.py | 11 +++- tensorrt_llm/executor/worker.py | 3 +- tensorrt_llm/inputs/multimodal.py | 59 +++++++++++++------ tensorrt_llm/inputs/utils.py | 46 ++++++++++----- tensorrt_llm/llmapi/llm.py | 15 ++--- 6 files changed, 130 insertions(+), 63 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 3c889942e3d..8a898d738f3 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -1,6 +1,5 @@ import copy -import os -from typing import Dict, List, Optional, Tuple, Union, Any +from typing import Any, Dict, List, Optional, Tuple, Union import torch from PIL.Image import Image @@ -809,7 +808,7 @@ def __init__(self, self.vocab_size = model_config.text_config.vocab_size self.image_token_index = model_config.image_token_index # TODO: use tokenizer to get these special tokens/values - self.fake_image_token = "<|image|>" # this must be the same as placeholder prompt + self.fake_image_token = "<|image|>" # this must be the same as placeholder prompt self.image_token = "<|patch|>" self.image_token_start_index = 200080 self.image_token_end_index = 200081 @@ -822,7 +821,12 @@ def __init__(self, }).cuda() load_sharded_checkpoint(self.encoder, model_path, strict=False) - def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[Dict[str, Any]]], sampling_params: SamplingParams) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + def attch_multimodal_embeddings( + self, inputs: TextPrompt, multimodal_embedding: Dict[str, + List[Dict[str, + Any]]], + sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: """ Attach pre-processed multimodal embeddings into text token stream for Llama4 model. 
@@ -851,15 +855,26 @@ def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: mm_embedding_info = multimodal_embedding['image'] if not mm_embedding_info or not isinstance(mm_embedding_info[0], dict): - raise ValueError("Llama4 image embedding must contain special token information") + raise ValueError( + "Llama4 image embedding must contain special token information") # Extract embedding components try: - mm_embeddings = [mm_embedding['mm_embeddings'] for mm_embedding in mm_embedding_info] - mm_embedding_special_tokens = [mm_embedding['image_special_tokens'] for mm_embedding in mm_embedding_info] - mm_embedding_special_offsets = [mm_embedding['image_special_token_offsets'] for mm_embedding in mm_embedding_info] + mm_embeddings = [ + mm_embedding['mm_embeddings'] + for mm_embedding in mm_embedding_info + ] + mm_embedding_special_tokens = [ + mm_embedding['image_special_tokens'] + for mm_embedding in mm_embedding_info + ] + mm_embedding_special_offsets = [ + mm_embedding['image_special_token_offsets'] + for mm_embedding in mm_embedding_info + ] except KeyError as e: - raise ValueError(f"Missing required key in multimodal embedding: {e}") + raise ValueError( + f"Missing required key in multimodal embedding: {e}") # Validate embedding dimensions model_hidden_size = self.model_config.text_config.hidden_size @@ -867,19 +882,19 @@ def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: if embedding.shape[-1] != model_hidden_size: raise ValueError( f"Multimodal embedding {i} hidden size {embedding.shape[-1]} " - f"must match model hidden size {model_hidden_size}" - ) + f"must match model hidden size {model_hidden_size}") # Count image placeholders (number of images) in the prompt total_placeholders = text_prompt.count(self.fake_image_token) if total_placeholders == 0: - raise ValueError("No image placeholders found in the prompt, but multimodal embedding was provided") + raise ValueError( + "No image placeholders found in the prompt, but multimodal embedding was provided" + ) if total_placeholders != len(mm_embeddings): raise ValueError( f"Number of image placeholders ({total_placeholders}) " - f"does not match number of embeddings ({len(mm_embeddings)})" - ) + f"does not match number of embeddings ({len(mm_embeddings)})") # Process prompt with image embeddings prompt_splits = text_prompt.split(self.fake_image_token) @@ -890,14 +905,16 @@ def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: if local_image_index < total_placeholders: # Calculate total tokens for this image - num_tokens = len(mm_embeddings[local_image_index]) + len(mm_embedding_special_tokens[local_image_index]) + num_tokens = len(mm_embeddings[local_image_index]) + len( + mm_embedding_special_tokens[local_image_index]) # Create image token sequence image_tokens = [self.image_token] * num_tokens # Replace special tokens with actual decoded tokens - for offset, token_id in zip(mm_embedding_special_offsets[local_image_index], - mm_embedding_special_tokens[local_image_index]): + for offset, token_id in zip( + mm_embedding_special_offsets[local_image_index], + mm_embedding_special_tokens[local_image_index]): if offset < len(image_tokens): image_tokens[offset] = self.tokenizer.decode([token_id]) @@ -911,14 +928,18 @@ def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: if sampling_params.truncate_prompt_tokens is not None: kwargs = dict(truncation=True, max_length=sampling_params.truncate_prompt_tokens) - text_inputs = 
self.tokenizer(processed_text, add_special_tokens=sampling_params.add_special_tokens, **kwargs) + text_inputs = self.tokenizer( + processed_text, + add_special_tokens=sampling_params.add_special_tokens, + **kwargs) token_ids = text_inputs.input_ids.squeeze() # Replace image token indices with out-of-vocabulary tokens token_ids[token_ids == self.image_token_index] = self.vocab_size + 1 # Concatenate all multimodal embeddings multimodal_data = {} - multimodal_data["multimodal_embedding"] = torch.cat(mm_embeddings, dim=0) + multimodal_data["multimodal_embedding"] = torch.cat(mm_embeddings, + dim=0) return token_ids.tolist(), {"multimodal_data": multimodal_data} @torch.inference_mode() diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 74a26dabe95..01d9bbb51d3 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -1,6 +1,6 @@ import copy import os -from typing import List, Optional, Tuple, Dict +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -195,7 +195,11 @@ def _postprocess(self, input_ids, mm_features): mm_features = mm_features.view(-1, mm_features.shape[-1]) return fused_input_ids, mm_features - def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[torch.Tensor]], sampling_params: SamplingParams) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + def attch_multimodal_embeddings( + self, inputs: TextPrompt, + multimodal_embedding: Dict[str, List[torch.Tensor]], + sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: """ Attach pre-processed multimodal embeddings into text token stream for Llama4 model. @@ -216,7 +220,8 @@ def attch_multimodal_embeddings(self, inputs: TextPrompt, multimodal_embedding: input_ids = self.tokenizer( text_prompt, return_tensors="pt").input_ids[0].to(self.device) - fused_input_ids, mm_features = self._postprocess(input_ids, multimodal_embedding['image']) + fused_input_ids, mm_features = self._postprocess( + input_ids, multimodal_embedding['image']) multimodal_data = {} multimodal_data["multimodal_embedding"] = mm_features return fused_input_ids.to(torch.int32).tolist(), { diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index b4b81fa8653..9ce0a8006bb 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -501,7 +501,8 @@ def _deduce_max_tokens(request: GenerationRequest, if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: # Convert back to tensor, as opposite to `to_handle` in `llm.generate_async` - request.multimodal_params.to_tensor("multimodal_data", key="multimodal_embedding") + request.multimodal_params.to_tensor( + "multimodal_data", key="multimodal_embedding") executor_request.py_multimodal_data = request.multimodal_params.multimodal_data if self._is_pytorch_backend and request.sampling_params.logits_processor: diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index f230e9954b3..afb16d0f396 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -234,7 +234,9 @@ def _to_device( f"MultimodalParams: Unsupported element '{element}' to move to device. 
" f"Supported elements: 'multimodal_data', 'multimodal_input'") - def to_handle(self, element: str, key: Optional[str] = "multimodal_embedding") -> None: + def to_handle(self, + element: str, + key: Optional[str] = "multimodal_embedding") -> None: """Convert multimodal data to tensor handle. Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries) @@ -266,7 +268,8 @@ def _to_tensor_handle(data): elif isinstance(v, list): for i, item in enumerate(v): if isinstance(item, torch.Tensor): - handle = SharedTensorContainer.from_tensor(item).dump_to_dict() + handle = SharedTensorContainer.from_tensor( + item).dump_to_dict() v[i] = handle if element == "multimodal_data": @@ -276,23 +279,30 @@ def _to_tensor_handle(data): _to_tensor_handle(self.multimodal_data) else: if key not in self.multimodal_data: - raise ValueError(f"Key '{key}' not found in multimodal_data") + raise ValueError( + f"Key '{key}' not found in multimodal_data") value = self.multimodal_data[key] if isinstance(value, torch.Tensor): - handle = SharedTensorContainer.from_tensor(value).dump_to_dict() + handle = SharedTensorContainer.from_tensor( + value).dump_to_dict() self.multimodal_data[key] = handle elif isinstance(value, dict): _to_tensor_handle(value) else: - raise ValueError(f"Unsupported value type for multimodal_data: {type(value)}") + raise ValueError( + f"Unsupported value type for multimodal_data: {type(value)}" + ) elif element == "multimodal_input": # No-op for multimodal_input return else: - raise ValueError(f"Unsupported element '{element}' to convert to handle.") + raise ValueError( + f"Unsupported element '{element}' to convert to handle.") - def to_tensor(self, element: str, key: Optional[str] = "multimodal_embedding") -> None: + def to_tensor(self, + element: str, + key: Optional[str] = "multimodal_embedding") -> None: """Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle. 
Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects @@ -318,20 +328,26 @@ def _to_tensor(data): if isinstance(v, dict) and 'method_key' in v: # This is a tensor handle (dict with method_key) try: - tensor = SharedTensorContainer.from_dict(v).get_local_view() + tensor = SharedTensorContainer.from_dict( + v).get_local_view() data[k] = tensor except Exception as e: - raise ValueError(f"Failed to convert handle to tensor for key '{k}': {e}") + raise ValueError( + f"Failed to convert handle to tensor for key '{k}': {e}" + ) elif isinstance(v, dict): _to_tensor(v) elif isinstance(v, list): for i, item in enumerate(v): if isinstance(item, dict) and 'method_key' in item: try: - tensor = SharedTensorContainer.from_dict(item).get_local_view() + tensor = SharedTensorContainer.from_dict( + item).get_local_view() v[i] = tensor except Exception as e: - raise ValueError(f"Failed to convert handle to tensor in list at index {i}: {e}") + raise ValueError( + f"Failed to convert handle to tensor in list at index {i}: {e}" + ) if element == "multimodal_data": if self.multimodal_data is None: @@ -341,25 +357,34 @@ def _to_tensor(data): _to_tensor(self.multimodal_data) else: if key not in self.multimodal_data: - raise ValueError(f"Key '{key}' not found in multimodal_data") + raise ValueError( + f"Key '{key}' not found in multimodal_data") value = self.multimodal_data[key] - if isinstance(value, dict) and 'method_key' in value: # This is a tensor handle + if isinstance( + value, dict + ) and 'method_key' in value: # This is a tensor handle try: - tensor = SharedTensorContainer.from_dict(value).get_local_view() + tensor = SharedTensorContainer.from_dict( + value).get_local_view() self.multimodal_data[key] = tensor except Exception as e: - raise ValueError(f"Failed to convert handle to tensor for key '{key}': {e}") + raise ValueError( + f"Failed to convert handle to tensor for key '{key}': {e}" + ) elif isinstance(value, dict): _to_tensor(value) else: - raise ValueError(f"Unsupported value type for multimodal_data: {type(value)}") + raise ValueError( + f"Unsupported value type for multimodal_data: {type(value)}" + ) elif element == "multimodal_input": # No-op for multimodal_input return else: - raise ValueError(f"Unsupported element '{element}' to convert to tensor.") + raise ValueError( + f"Unsupported element '{element}' to convert to tensor.") def strip_for_context(self) -> None: """Strip multimodal data for context processing. 
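
A hedged sketch of how pre-computed Llama4 embeddings are expected to flow through default_multimodal_input_loader (reformatted in the utils.py hunk below) and into llm.generate. Here `llm`, `sampling_params`, the model_type string, the tensor shape, and the concrete special-token values are illustrative assumptions; the three dict keys are the ones read by the Llama4 embedding-attach method added earlier in this series.

    import torch

    from tensorrt_llm.inputs import default_multimodal_input_loader

    # Per-image payload: embedding rows plus the special tokens (and their offsets)
    # interleaved into the image token span. The embedding hidden size must match
    # text_config.hidden_size; all concrete values here are placeholders.
    llama4_image_embedding = {
        "mm_embeddings": torch.randn(144, 5120),   # (num_image_tokens, hidden_size)
        "image_special_tokens": [200080, 200081],  # e.g. image start/end ids
        "image_special_token_offsets": [0, 145],   # positions inside the image span
    }

    # Assumes `llm` and `sampling_params` were set up as in quickstart_multimodal.py.
    inputs = default_multimodal_input_loader(
        tokenizer=llm.tokenizer,
        model_dir=llm._hf_model_dir,
        model_type="llama4",
        modality="image",
        prompts=["Describe the image."],
        media=["placeholder.png"],  # not re-loaded when mm_embeddings is provided
        mm_embeddings=[llama4_image_embedding],
        device="cuda",
    )
    outputs = llm.generate(inputs, sampling_params)
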
diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 621489466e0..6deaf52c3e4 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -480,13 +480,16 @@ def default_multimodal_input_loader( media: Union[List[str], List[List[str]]], image_data_format: str = "pt", num_frames: int = 8, - mm_embeddings: Optional[Union[List[torch.Tensor], List[Dict[str, Any]]]] = None, + mm_embeddings: Optional[Union[List[torch.Tensor], + List[Dict[str, Any]]]] = None, device: str = "cpu") -> List[dict[str, Union[str, torch.Tensor]]]: - def convert_to_conversation_message(prompt: str, media: Union[str, - List[str]], - modality: str, - mm_embedding: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> ConversationMessage: + def convert_to_conversation_message( + prompt: str, + media: Union[str, List[str]], + modality: str, + mm_embedding: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None + ) -> ConversationMessage: if isinstance(media, str): media = [media] if modality == "image": @@ -495,13 +498,17 @@ def convert_to_conversation_message(prompt: str, media: Union[str, if not isinstance(mm_embedding, list): mm_embedding = [mm_embedding] - mm_data = [{'modality': modality, 'mm_embedding_info': mm} for mm in mm_embedding] + mm_data = [{ + 'modality': modality, + 'mm_embedding_info': mm + } for mm in mm_embedding] else: mm_data = [ MultimodalData(modality=modality, - data=load_image(i, - format=image_data_format, - device=device)) for i in media + data=load_image(i, + format=image_data_format, + device=device)) + for i in media ] elif modality == "video": mm_data = [ @@ -563,17 +570,20 @@ def convert_to_conversation_message(prompt: str, media: Union[str, for prompt_idx, (prompt, media) in enumerate(zip(prompts, media)): if mm_embeddings is not None: mm_embedding = mm_embeddings[prompt_idx] - conv = convert_to_conversation_message(prompt, media, modality, mm_embedding) + conv = convert_to_conversation_message(prompt, media, modality, + mm_embedding) else: conv = convert_to_conversation_message(prompt, media, modality) mm_data_tracker = MultimodalDataTracker(model_type) for mdata in conv["media"]: # Check if mdata is a MultimodalData - if isinstance(mdata, dict) and "modality" in mdata and "data" in mdata: + if isinstance(mdata, + dict) and "modality" in mdata and "data" in mdata: mm_data_tracker.add_data(mdata["modality"], mdata["data"]) else: # Add embeddings to the tracker for placeholder handling - mm_data_tracker.add_data(mdata["modality"], mdata["mm_embedding_info"]) + mm_data_tracker.add_data(mdata["modality"], + mdata["mm_embedding_info"]) mm_placeholder_counts = mm_data_tracker.placeholder_counts() prompt = conv["content"] if mm_placeholder_counts: @@ -589,13 +599,17 @@ def convert_to_conversation_message(prompt: str, media: Union[str, if mm_embeddings is not None: inputs.append({ - "prompt": prompt, - "multi_modal_embeddings": mm_data_tracker.retrieve_all_sync() + "prompt": + prompt, + "multi_modal_embeddings": + mm_data_tracker.retrieve_all_sync() }) else: inputs.append({ - "prompt": prompt, - "multi_modal_data": mm_data_tracker.retrieve_all_sync() + "prompt": + prompt, + "multi_modal_data": + mm_data_tracker.retrieve_all_sync() }) return inputs diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index f3841361e97..86f6f8a694b 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -341,11 +341,10 @@ def generate_async( inputs = prompt_inputs(inputs) - if not inputs.get("prompt") and inputs.get( - 
"prompt_token_ids") and (inputs.get( - "multi_modal_data") or inputs.get( - "multi_modal_embeddings")) and not isinstance( - self.input_processor, DefaultInputProcessor): + if not inputs.get("prompt") and inputs.get("prompt_token_ids") and ( + inputs.get("multi_modal_data") + or inputs.get("multi_modal_embeddings")) and not isinstance( + self.input_processor, DefaultInputProcessor): # VLMs need to process/tokenize the prompt in their own way prompt = self.tokenizer.decode(inputs['prompt_token_ids']) inputs = TextPrompt( @@ -380,7 +379,8 @@ def generate_async( inputs, sampling_params) elif 'multi_modal_embeddings' in inputs: mm_embedding_info = inputs['multi_modal_embeddings'] - prompt_token_ids, extra_processed_inputs = self.input_processor.attch_multimodal_embeddings(inputs, mm_embedding_info, sampling_params) + prompt_token_ids, extra_processed_inputs = self.input_processor.attch_multimodal_embeddings( + inputs, mm_embedding_info, sampling_params) else: with nvtx_range_debug("input_processor"): prompt_token_ids, extra_processed_inputs = self.input_processor( @@ -395,7 +395,8 @@ def generate_async( multimodal_data=extra_processed_inputs.get( 'multimodal_data')) # Convert to shared tensor handle to reduce IPC overhead - multimodal_params.to_handle("multimodal_data", key="multimodal_embedding") + multimodal_params.to_handle("multimodal_data", + key="multimodal_embedding") # Only pass it if it has content if not multimodal_params.has_content(): multimodal_params = None From b5ba5b6f98b08e23ccc0fd43650ac8ea1ca1fe13 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Mon, 21 Jul 2025 21:13:38 -0700 Subject: [PATCH 05/13] tokenizer returns tensor --- tensorrt_llm/_torch/models/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 8a898d738f3..1e8a5af367c 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -930,6 +930,7 @@ def attch_multimodal_embeddings( max_length=sampling_params.truncate_prompt_tokens) text_inputs = self.tokenizer( processed_text, + return_tensors="pt", add_special_tokens=sampling_params.add_special_tokens, **kwargs) token_ids = text_inputs.input_ids.squeeze() From 78a52784e38b05e4746b153db9c270ec5d1cbfa3 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Mon, 21 Jul 2025 22:18:41 -0700 Subject: [PATCH 06/13] Add unit tests --- examples/pytorch/quickstart_multimodal.py | 115 ------------------ tensorrt_llm/inputs/multimodal.py | 8 +- .../multimodal/test_share_multiparams.py | 84 +++++++++++++ 3 files changed, 88 insertions(+), 119 deletions(-) delete mode 100644 examples/pytorch/quickstart_multimodal.py create mode 100644 tests/unittest/_torch/multimodal/test_share_multiparams.py diff --git a/examples/pytorch/quickstart_multimodal.py b/examples/pytorch/quickstart_multimodal.py deleted file mode 100644 index 75d0eba184b..00000000000 --- a/examples/pytorch/quickstart_multimodal.py +++ /dev/null @@ -1,115 +0,0 @@ -import argparse -import json -import os - -from quickstart_advanced import add_llm_args, setup_llm - -from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS, - default_multimodal_input_loader) - -example_images = [ - "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png", - 
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png", - "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg", -] -example_image_prompts = [ - "Describe the natural environment in the image.", - "Describe the object and the weather condition in the image.", - "Describe the traffic condition on the road in the image.", -] -example_videos = [ - "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4", - "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4", -] -example_video_prompts = [ - "Tell me what you see in the video briefly.", - "Describe the scene in the video briefly.", -] - - -def add_multimodal_args(parser): - parser.add_argument("--model_type", - type=str, - choices=ALL_SUPPORTED_MULTIMODAL_MODELS, - help="Model type.") - parser.add_argument("--modality", - type=str, - choices=["image", "video"], - default="image", - help="Media type.") - parser.add_argument("--media", - type=str, - nargs="+", - help="A single or a list of media filepaths / urls.") - parser.add_argument("--num_frames", - type=int, - default=8, - help="The number of video frames to be sampled.") - parser.add_argument("--image_format", - type=str, - choices=["pt", "pil"], - default="pt", - help="The format of the image.") - return parser - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description="Multimodal models with the PyTorch workflow.") - parser = add_llm_args(parser) - parser = add_multimodal_args(parser) - args = parser.parse_args() - - args.disable_kv_cache_reuse = True # kv cache reuse does not work for multimodal, force overwrite - if args.kv_cache_fraction is None: - args.kv_cache_fraction = 0.6 # lower the default kv cache fraction for multimodal - - return args - - -def main(): - args = parse_arguments() - # set prompts and media to example prompts and images if they are not provided - if args.prompt is None: - args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts - if args.media is None: - args.media = example_images if args.modality == "image" else example_videos - - llm, sampling_params = setup_llm(args) - - image_format = args.image_format - if args.model_type is not None: - model_type = args.model_type - else: - model_type = json.load( - open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type'] - assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}" - - device = "cuda" - import torch - mm_embeds_1 = torch.load("scout_mm_embed_seashore_v2.pt") - mm_embeds_2 = torch.load("scout_mm_embed_inpaint_v2.pt") - mm_embeds_3 = torch.load("scout_mm_embed_highway_61_v2.pt") - mm_embeds = [mm_embeds_1, mm_embeds_2, mm_embeds_3] - - inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer, - model_dir=llm._hf_model_dir, - model_type=model_type, - modality=args.modality, - prompts=args.prompt, - media=args.media, - image_data_format=image_format, - num_frames=args.num_frames, - mm_embeddings=mm_embeds, - device=device) - - outputs = llm.generate(inputs, sampling_params) - - for i, output in enumerate(outputs): - prompt = args.prompt[i] - generated_text = output.outputs[0].text - print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -if __name__ == "__main__": - main() diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index afb16d0f396..6957bdecffd 100644 --- 
a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -236,7 +236,7 @@ def _to_device( def to_handle(self, element: str, - key: Optional[str] = "multimodal_embedding") -> None: + key: Optional[str] = None) -> None: """Convert multimodal data to tensor handle. Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries) @@ -245,7 +245,7 @@ def to_handle(self, Args: element: Element to convert ("multimodal_data" or "multimodal_input") key: Specific key to convert. If None, converts all tensor values in multimodal_data. - Defaults to "multimodal_embedding". + Defaults to None. Example: # Convert all tensors in multimodal_data to handles @@ -302,7 +302,7 @@ def _to_tensor_handle(data): def to_tensor(self, element: str, - key: Optional[str] = "multimodal_embedding") -> None: + key: Optional[str] = None) -> None: """Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle. Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects @@ -311,7 +311,7 @@ def to_tensor(self, Args: element: Element to convert ("multimodal_data" or "multimodal_input") key: Specific key to convert. If None, converts all tensor handles in multimodal_data. - Defaults to "multimodal_embedding". + Defaults to None. Example: # Convert all handles back to tensors diff --git a/tests/unittest/_torch/multimodal/test_share_multiparams.py b/tests/unittest/_torch/multimodal/test_share_multiparams.py new file mode 100644 index 00000000000..f579b6ba15e --- /dev/null +++ b/tests/unittest/_torch/multimodal/test_share_multiparams.py @@ -0,0 +1,84 @@ +import unittest + +import torch + +from tensorrt_llm.inputs.multimodal import MultimodalParams, MultimodalInput + + +class TestMultimodalParamsHandleConversion(unittest.TestCase): + """Test cases for to_handle and to_tensor methods in MultimodalParams.""" + + def setUp(self): + """Set up test fixtures.""" + # Create sample cpu tensors for testing (shared cuda tensor using cudaIPC only works between processes) + self.mm_embedding = torch.randn(3, 4, 5) + self.mrope_config = { + "mrope_rotary_cos_sin": torch.randn(2, 3), + "mrope_position_deltas": torch.randn(5), + } + self.image = { + "pixel_values": torch.randn(1, 3, 224, 224), + "image_height": [224], + "image_width": [224], + } + # Create sample multimodal data structure + self.sample_multimodal_data = { + "multimodal_embedding": self.mm_embedding, + "mrope_config": self.mrope_config, + "image": self.image, + } + + def test_to_handle_none_multimodal_data(self): + """Test to_handle with None multimodal_data.""" + params = MultimodalParams() + params.multimodal_data = None + + params.to_handle("multimodal_data") + self.assertIsNone(params.multimodal_data) + params.multimodal_data = {} + params.to_handle("multimodal_data") + self.assertEqual(params.multimodal_data, {}) + + params = MultimodalParams() + multimodal_input = MultimodalInput( + multimodal_hashes=[[1,2,3,4,5,6,7,8]]*2, + multimodal_positions=[0, 10], + multimodal_lengths=[2, 2] + ) + params.multimodal_input = multimodal_input + params.to_handle("multimodal_input") + self.assertEqual(params.multimodal_input, multimodal_input) + + def test_to_tensor_basic_handle(self): + """Test converting a basic handle back to tensor.""" + params = MultimodalParams() + params.multimodal_data = { + "multimodal_embedding": self.mm_embedding + } + + # Convert to handle + params.to_handle("multimodal_data", key="multimodal_embedding") + # Convert back to tensor + 
params.to_tensor("multimodal_data", key="multimodal_embedding") + + result = params.multimodal_data["multimodal_embedding"] + self.assertIsInstance(result, torch.Tensor) + self.assertTrue(torch.allclose(result, self.mm_embedding)) + + def test_to_tensor_all_handles(self): + """Test that to_handle followed by to_tensor preserves data integrity.""" + params = MultimodalParams() + params.multimodal_data = self.sample_multimodal_data.copy() + + params.to_handle("multimodal_data", key=None) + params.to_tensor("multimodal_data", key=None) + + self.assertTrue(torch.allclose(params.multimodal_data["multimodal_embedding"], self.mm_embedding)) + self.assertTrue(torch.allclose(params.multimodal_data["mrope_config"]["mrope_rotary_cos_sin"], self.mrope_config["mrope_rotary_cos_sin"])) + self.assertTrue(torch.allclose(params.multimodal_data["mrope_config"]["mrope_position_deltas"], self.mrope_config["mrope_position_deltas"])) + self.assertTrue(torch.allclose(params.multimodal_data["image"]["pixel_values"], self.image["pixel_values"])) + self.assertEqual(params.multimodal_data["image"]["image_height"], self.image["image_height"]) + self.assertEqual(params.multimodal_data["image"]["image_width"], self.image["image_width"]) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From efdd13be7896b873f84ae83be12e8229894a0400 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Mon, 21 Jul 2025 22:22:19 -0700 Subject: [PATCH 07/13] Format --- tensorrt_llm/inputs/multimodal.py | 8 +--- .../multimodal/test_share_multiparams.py | 46 +++++++++++-------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index 6957bdecffd..b675000e41a 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -234,9 +234,7 @@ def _to_device( f"MultimodalParams: Unsupported element '{element}' to move to device. " f"Supported elements: 'multimodal_data', 'multimodal_input'") - def to_handle(self, - element: str, - key: Optional[str] = None) -> None: + def to_handle(self, element: str, key: Optional[str] = None) -> None: """Convert multimodal data to tensor handle. Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries) @@ -300,9 +298,7 @@ def _to_tensor_handle(data): raise ValueError( f"Unsupported element '{element}' to convert to handle.") - def to_tensor(self, - element: str, - key: Optional[str] = None) -> None: + def to_tensor(self, element: str, key: Optional[str] = None) -> None: """Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle. 
Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects diff --git a/tests/unittest/_torch/multimodal/test_share_multiparams.py b/tests/unittest/_torch/multimodal/test_share_multiparams.py index f579b6ba15e..d4ce40f6332 100644 --- a/tests/unittest/_torch/multimodal/test_share_multiparams.py +++ b/tests/unittest/_torch/multimodal/test_share_multiparams.py @@ -2,7 +2,7 @@ import torch -from tensorrt_llm.inputs.multimodal import MultimodalParams, MultimodalInput +from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams class TestMultimodalParamsHandleConversion(unittest.TestCase): @@ -41,10 +41,9 @@ def test_to_handle_none_multimodal_data(self): params = MultimodalParams() multimodal_input = MultimodalInput( - multimodal_hashes=[[1,2,3,4,5,6,7,8]]*2, + multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]] * 2, multimodal_positions=[0, 10], - multimodal_lengths=[2, 2] - ) + multimodal_lengths=[2, 2]) params.multimodal_input = multimodal_input params.to_handle("multimodal_input") self.assertEqual(params.multimodal_input, multimodal_input) @@ -52,15 +51,13 @@ def test_to_handle_none_multimodal_data(self): def test_to_tensor_basic_handle(self): """Test converting a basic handle back to tensor.""" params = MultimodalParams() - params.multimodal_data = { - "multimodal_embedding": self.mm_embedding - } - + params.multimodal_data = {"multimodal_embedding": self.mm_embedding} + # Convert to handle params.to_handle("multimodal_data", key="multimodal_embedding") # Convert back to tensor params.to_tensor("multimodal_data", key="multimodal_embedding") - + result = params.multimodal_data["multimodal_embedding"] self.assertIsInstance(result, torch.Tensor) self.assertTrue(torch.allclose(result, self.mm_embedding)) @@ -69,16 +66,29 @@ def test_to_tensor_all_handles(self): """Test that to_handle followed by to_tensor preserves data integrity.""" params = MultimodalParams() params.multimodal_data = self.sample_multimodal_data.copy() - + params.to_handle("multimodal_data", key=None) params.to_tensor("multimodal_data", key=None) - - self.assertTrue(torch.allclose(params.multimodal_data["multimodal_embedding"], self.mm_embedding)) - self.assertTrue(torch.allclose(params.multimodal_data["mrope_config"]["mrope_rotary_cos_sin"], self.mrope_config["mrope_rotary_cos_sin"])) - self.assertTrue(torch.allclose(params.multimodal_data["mrope_config"]["mrope_position_deltas"], self.mrope_config["mrope_position_deltas"])) - self.assertTrue(torch.allclose(params.multimodal_data["image"]["pixel_values"], self.image["pixel_values"])) - self.assertEqual(params.multimodal_data["image"]["image_height"], self.image["image_height"]) - self.assertEqual(params.multimodal_data["image"]["image_width"], self.image["image_width"]) + + self.assertTrue( + torch.allclose(params.multimodal_data["multimodal_embedding"], + self.mm_embedding)) + self.assertTrue( + torch.allclose( + params.multimodal_data["mrope_config"]["mrope_rotary_cos_sin"], + self.mrope_config["mrope_rotary_cos_sin"])) + self.assertTrue( + torch.allclose( + params.multimodal_data["mrope_config"]["mrope_position_deltas"], + self.mrope_config["mrope_position_deltas"])) + self.assertTrue( + torch.allclose(params.multimodal_data["image"]["pixel_values"], + self.image["pixel_values"])) + self.assertEqual(params.multimodal_data["image"]["image_height"], + self.image["image_height"]) + self.assertEqual(params.multimodal_data["image"]["image_width"], + self.image["image_width"]) + if __name__ == "__main__": - unittest.main() \ No 
newline at end of file + unittest.main() From 23e7325c5f153881729c8c378de33b03355c3165 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Tue, 22 Jul 2025 11:09:02 -0700 Subject: [PATCH 08/13] Avoid hard-code special image tokens --- tensorrt_llm/_torch/models/modeling_llama.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 1e8a5af367c..315fc37f3f4 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -807,12 +807,10 @@ def __init__(self, self.tokenizer = tokenizer self.vocab_size = model_config.text_config.vocab_size self.image_token_index = model_config.image_token_index - # TODO: use tokenizer to get these special tokens/values - self.fake_image_token = "<|image|>" # this must be the same as placeholder prompt - self.image_token = "<|patch|>" - self.image_token_start_index = 200080 - self.image_token_end_index = 200081 - + self.fake_image_token = self.processor.fake_image_token + self.image_token = self.processor.img_patch_token + self.image_token_start_index = self.model_config.boi_token_index + self.image_token_end_index = self.model_config.eoi_token_index self.encoder = nn.ModuleDict({ "vision_model": Llama4VisionModel(model_config.vision_config), From cd7393a1b4fdcfe96f37b08a0c69a05cb79a27df Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Thu, 24 Jul 2025 14:45:16 -0700 Subject: [PATCH 09/13] Address comment --- tensorrt_llm/_torch/models/modeling_llama.py | 2 +- tensorrt_llm/_torch/models/modeling_llava_next.py | 2 +- tensorrt_llm/llmapi/llm.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 315fc37f3f4..fed075944f0 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -819,7 +819,7 @@ def __init__(self, }).cuda() load_sharded_checkpoint(self.encoder, model_path, strict=False) - def attch_multimodal_embeddings( + def attach_multimodal_embeddings( self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[Dict[str, Any]]], diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 01d9bbb51d3..41dccfeac17 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -195,7 +195,7 @@ def _postprocess(self, input_ids, mm_features): mm_features = mm_features.view(-1, mm_features.shape[-1]) return fused_input_ids, mm_features - def attch_multimodal_embeddings( + def attach_multimodal_embeddings( self, inputs: TextPrompt, multimodal_embedding: Dict[str, List[torch.Tensor]], sampling_params: SamplingParams diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 86f6f8a694b..a74318b920c 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -379,7 +379,7 @@ def generate_async( inputs, sampling_params) elif 'multi_modal_embeddings' in inputs: mm_embedding_info = inputs['multi_modal_embeddings'] - prompt_token_ids, extra_processed_inputs = self.input_processor.attch_multimodal_embeddings( + prompt_token_ids, extra_processed_inputs = self.input_processor.attach_multimodal_embeddings( inputs, mm_embedding_info, sampling_params) else: with 
nvtx_range_debug("input_processor"): From 69c1fc1e04805aaf04c76cd6c9f9404bfd956f56 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Thu, 24 Jul 2025 17:03:56 -0700 Subject: [PATCH 10/13] Update to cover corner cases --- tensorrt_llm/_torch/models/modeling_llama.py | 4 ++++ tensorrt_llm/_torch/models/modeling_llava_next.py | 2 +- tensorrt_llm/executor/worker.py | 8 ++++++++ tensorrt_llm/inputs/multimodal.py | 6 ++---- tensorrt_llm/llmapi/llm.py | 1 + 5 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index fed075944f0..f89f6f6bbe9 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -913,6 +913,10 @@ def attach_multimodal_embeddings( for offset, token_id in zip( mm_embedding_special_offsets[local_image_index], mm_embedding_special_tokens[local_image_index]): + if offset < 0 or offset >= len(image_tokens): + raise ValueError( + f"Image special token offset {offset} is out of range with the total image tokens length {len(image_tokens)}" + ) if offset < len(image_tokens): image_tokens[offset] = self.tokenizer.decode([token_id]) diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 41dccfeac17..0370deb71e4 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -201,7 +201,7 @@ def attach_multimodal_embeddings( sampling_params: SamplingParams ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: """ - Attach pre-processed multimodal embeddings into text token stream for Llama4 model. + Attach pre-processed multimodal embeddings into text token stream for LlavaNext model. This method skips vision processing and works with externally provided embeddings. 
It replaces/expands image placeholders in the text with appropriate tokens and prepares diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 9ce0a8006bb..577325bfcf8 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -501,8 +501,16 @@ def _deduce_max_tokens(request: GenerationRequest, if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: # Convert back to tensor, as opposite to `to_handle` in `llm.generate_async` + # for values with non-selected keys, it's no-op request.multimodal_params.to_tensor( "multimodal_data", key="multimodal_embedding") + embedding = request.multimodal_params.multimodal_data.get( + "multimodal_embedding") + if embedding is not None and embedding.is_cuda: + # make sure the embedding resides on the local device + request.multimodal_params.multimodal_data[ + "multimodal_embedding"] = embedding.to("cuda") + executor_request.py_multimodal_data = request.multimodal_params.multimodal_data if self._is_pytorch_backend and request.sampling_params.logits_processor: diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index b675000e41a..93689065879 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -277,8 +277,7 @@ def _to_tensor_handle(data): _to_tensor_handle(self.multimodal_data) else: if key not in self.multimodal_data: - raise ValueError( - f"Key '{key}' not found in multimodal_data") + return # no-op if key not found value = self.multimodal_data[key] if isinstance(value, torch.Tensor): @@ -353,8 +352,7 @@ def _to_tensor(data): _to_tensor(self.multimodal_data) else: if key not in self.multimodal_data: - raise ValueError( - f"Key '{key}' not found in multimodal_data") + return # no-op if key not found value = self.multimodal_data[key] if isinstance( diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index a74318b920c..21d02c79caa 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -395,6 +395,7 @@ def generate_async( multimodal_data=extra_processed_inputs.get( 'multimodal_data')) # Convert to shared tensor handle to reduce IPC overhead + # for values with non-selected keys, it's no-op multimodal_params.to_handle("multimodal_data", key="multimodal_embedding") # Only pass it if it has content From a20625c37b36f7f7e586b84cc54376e1c14acd62 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Sat, 26 Jul 2025 22:31:24 -0700 Subject: [PATCH 11/13] Update to only accept either image_url or image_emb --- .../_torch/models/modeling_llava_next.py | 3 +- tensorrt_llm/inputs/utils.py | 45 +++++++++++-------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 0370deb71e4..684673403a6 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -220,8 +220,9 @@ def attach_multimodal_embeddings( input_ids = self.tokenizer( text_prompt, return_tensors="pt").input_ids[0].to(self.device) + mm_features = torch.stack(multimodal_embedding['image']) fused_input_ids, mm_features = self._postprocess( - input_ids, multimodal_embedding['image']) + input_ids, mm_features) multimodal_data = {} multimodal_data["multimodal_embedding"] = mm_features return fused_input_ids.to(torch.int32).tolist(), { diff --git 
a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 6deaf52c3e4..d1c933a11c5 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -477,31 +477,31 @@ def default_multimodal_input_loader( model_type: str, modality: str, prompts: List[str], - media: Union[List[str], List[List[str]]], + media: Optional[Union[List[str], List[List[str]]]] = None, image_data_format: str = "pt", num_frames: int = 8, mm_embeddings: Optional[Union[List[torch.Tensor], - List[Dict[str, Any]]]] = None, + List[List[torch.Tensor]]]] = None, device: str = "cpu") -> List[dict[str, Union[str, torch.Tensor]]]: def convert_to_conversation_message( prompt: str, - media: Union[str, List[str]], + media: Union[Any, List[Any]], modality: str, - mm_embedding: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None + is_embedding: bool = False, ) -> ConversationMessage: if isinstance(media, str): media = [media] if modality == "image": - if mm_embedding is not None: + if is_embedding: # each mm_embedding corresponds to each image placeholder - if not isinstance(mm_embedding, list): - mm_embedding = [mm_embedding] + if not isinstance(media, list): + media = [media] mm_data = [{ 'modality': modality, 'mm_embedding_info': mm - } for mm in mm_embedding] + } for mm in media] else: mm_data = [ MultimodalData(modality=modality, @@ -511,6 +511,8 @@ def convert_to_conversation_message( for i in media ] elif modality == "video": + if is_embedding: + raise ValueError("External embedding is not supported for video modality yet.") mm_data = [ MultimodalData(modality=modality, data=load_video(i, @@ -519,11 +521,15 @@ def convert_to_conversation_message( device=device)) for i in media ] elif modality == "audio": + if is_embedding: + raise ValueError("External embedding is not supported for audio modality yet.") mm_data = [ MultimodalData(modality=modality, data=load_audio(i, device=device)) for i in media ] elif modality == "image_audio": + if is_embedding: + raise ValueError("External embedding is not supported for image_audio modality yet.") # Use different load_xxx functions to match the modality. mm_data = [] for m in media: @@ -550,12 +556,17 @@ def convert_to_conversation_message( raise ValueError(f"Unknown modality: {modality}") return ConversationMessage(role="user", content=prompt, media=mm_data) - if len(media) > len(prompts) and len(prompts) == 1: + assert media is not None or mm_embeddings is not None, "Either media or mm_embeddings must be provided." + assert media is None or mm_embeddings is None, "Either media or mm_embeddings must be provided, not both." 
+ media_or_embeddings = media if media is not None else mm_embeddings + is_embedding = mm_embeddings is not None + + if len(media_or_embeddings) > len(prompts) and len(prompts) == 1: # 1 prompt + N media assert not isinstance( - media[0], list) # media cannot be a list of lists in this case - media = [media] - assert len(media) == len(prompts) + media_or_embeddings[0], list) # media cannot be a list of lists in this case + media_or_embeddings = [media_or_embeddings] + assert len(media_or_embeddings) == len(prompts) if tokenizer is None and model_type not in HF_CHAT_TEMPLATE_EXCEPTIONS: tokenizer = ModelLoader.load_hf_tokenizer(model_dir, use_fast=True) @@ -567,13 +578,9 @@ def convert_to_conversation_message( trust_remote_code=True) inputs = [] - for prompt_idx, (prompt, media) in enumerate(zip(prompts, media)): - if mm_embeddings is not None: - mm_embedding = mm_embeddings[prompt_idx] - conv = convert_to_conversation_message(prompt, media, modality, - mm_embedding) - else: - conv = convert_to_conversation_message(prompt, media, modality) + for prompt_idx, (prompt, media) in enumerate(zip(prompts, media_or_embeddings)): + conv = convert_to_conversation_message(prompt, media, modality, + is_embedding) mm_data_tracker = MultimodalDataTracker(model_type) for mdata in conv["media"]: # Check if mdata is a MultimodalData From d9c29f9712af801b7e6645d37482b16ccf5eb268 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Sat, 26 Jul 2025 22:39:44 -0700 Subject: [PATCH 12/13] Format --- .../_torch/models/modeling_llava_next.py | 3 +-- tensorrt_llm/inputs/utils.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 684673403a6..6cd2f587515 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -221,8 +221,7 @@ def attach_multimodal_embeddings( input_ids = self.tokenizer( text_prompt, return_tensors="pt").input_ids[0].to(self.device) mm_features = torch.stack(multimodal_embedding['image']) - fused_input_ids, mm_features = self._postprocess( - input_ids, mm_features) + fused_input_ids, mm_features = self._postprocess(input_ids, mm_features) multimodal_data = {} multimodal_data["multimodal_embedding"] = mm_features return fused_input_ids.to(torch.int32).tolist(), { diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index d1c933a11c5..8064aca3e3a 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -512,7 +512,9 @@ def convert_to_conversation_message( ] elif modality == "video": if is_embedding: - raise ValueError("External embedding is not supported for video modality yet.") + raise ValueError( + "External embedding is not supported for video modality yet." + ) mm_data = [ MultimodalData(modality=modality, data=load_video(i, @@ -522,14 +524,18 @@ def convert_to_conversation_message( ] elif modality == "audio": if is_embedding: - raise ValueError("External embedding is not supported for audio modality yet.") + raise ValueError( + "External embedding is not supported for audio modality yet." 
+ ) mm_data = [ MultimodalData(modality=modality, data=load_audio(i, device=device)) for i in media ] elif modality == "image_audio": if is_embedding: - raise ValueError("External embedding is not supported for image_audio modality yet.") + raise ValueError( + "External embedding is not supported for image_audio modality yet." + ) # Use different load_xxx functions to match the modality. mm_data = [] for m in media: @@ -564,7 +570,8 @@ def convert_to_conversation_message( if len(media_or_embeddings) > len(prompts) and len(prompts) == 1: # 1 prompt + N media assert not isinstance( - media_or_embeddings[0], list) # media cannot be a list of lists in this case + media_or_embeddings[0], + list) # media cannot be a list of lists in this case media_or_embeddings = [media_or_embeddings] assert len(media_or_embeddings) == len(prompts) @@ -578,7 +585,8 @@ def convert_to_conversation_message( trust_remote_code=True) inputs = [] - for prompt_idx, (prompt, media) in enumerate(zip(prompts, media_or_embeddings)): + for prompt_idx, (prompt, + media) in enumerate(zip(prompts, media_or_embeddings)): conv = convert_to_conversation_message(prompt, media, modality, is_embedding) mm_data_tracker = MultimodalDataTracker(model_type) From 0a83b39222595a73ffdbfb124cca1ad45df07990 Mon Sep 17 00:00:00 2001 From: "Chang Liu (Enterprise Products)" <9713593+chang-l@users.noreply.github.com> Date: Mon, 28 Jul 2025 10:01:39 -0700 Subject: [PATCH 13/13] Address comment --- tensorrt_llm/_torch/models/modeling_llama.py | 4 +++- tensorrt_llm/_torch/models/modeling_llava_next.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index f89f6f6bbe9..b84910e228e 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -849,7 +849,9 @@ def attach_multimodal_embeddings( raise ValueError("multimodal_embedding must be a dictionary") if 'image' not in multimodal_embedding: - raise ValueError("Only image modality is supported for now") + raise ValueError( + "Only image modality is supported for external multimodal embedding" + ) mm_embedding_info = multimodal_embedding['image'] if not mm_embedding_info or not isinstance(mm_embedding_info[0], dict): diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 6cd2f587515..cf5e0778822 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -216,7 +216,16 @@ def attach_multimodal_embeddings( - extra_processed_inputs: Optional dictionary containing multimodal embeddings """ text_prompt = inputs.get("prompt") - assert 'image' in multimodal_embedding + if not text_prompt: + raise ValueError("Text prompt is required but not provided") + + if not isinstance(multimodal_embedding, dict): + raise ValueError("multimodal_embedding must be a dictionary") + + if 'image' not in multimodal_embedding: + raise ValueError( + "Only image modality is supported for external multimodal embedding" + ) input_ids = self.tokenizer( text_prompt, return_tensors="pt").input_ids[0].to(self.device)
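Reviewer note (illustrative, not part of the patch): the handle-conversion behavior changed above can be exercised end to end as in the minimal sketch below. It mirrors test_share_multiparams.py and assumes CPU tensors, since CUDA-IPC-backed handles only resolve across processes, as that test's comment notes.

    import torch

    from tensorrt_llm.inputs.multimodal import MultimodalParams

    params = MultimodalParams()
    params.multimodal_data = {"multimodal_embedding": torch.randn(3, 4, 5)}

    # Serialize only the embedding into a SharedTensorContainer handle for IPC.
    params.to_handle("multimodal_data", key="multimodal_embedding")

    # With the change in this series, selecting a key that is absent in
    # multimodal_data is a no-op rather than a ValueError.
    params.to_handle("multimodal_data", key="mrope_config")

    # On the consumer side (see the worker.py hunk), rebuild the tensor
    # from the handle before use.
    params.to_tensor("multimodal_data", key="multimodal_embedding")
    assert isinstance(params.multimodal_data["multimodal_embedding"], torch.Tensor)

The explicit key arguments here match the producer/consumer pairing used in llm.py and worker.py; passing key=None (the new default) would instead convert every tensor value in multimodal_data.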