diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index e6044133f93..f98ba385ca7 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1227,11 +1227,13 @@ def _prepare_tp_inputs(
                 multimodal_params = MultimodalParams(
                     multimodal_data=request.py_multimodal_data,
                     multimodal_runtime=py_multimodal_runtime)
-                multimodal_params.to_device("multimodal_data",
-                                            "cuda",
-                                            pin_memory=True)
                 if multimodal_params.has_content():
+                    multimodal_params.to_device("multimodal_data",
+                                                "cuda",
+                                                pin_memory=True)
+                    # Re-assign the multimodal_data to the request after to_device for generation requests
+                    request.py_multimodal_data = multimodal_params.multimodal_data
                     multimodal_params_list.append(multimodal_params)
 
                 request.py_batch_idx = request.py_seq_slot
 
@@ -1265,10 +1267,12 @@ def _prepare_tp_inputs(
                 multimodal_params = MultimodalParams(
                     multimodal_data=request.py_multimodal_data)
                 multimodal_params.strip_for_generation()
-                multimodal_params.to_device("multimodal_data",
-                                            "cuda",
-                                            pin_memory=True)
                 if multimodal_params.has_content():
+                    multimodal_params.to_device("multimodal_data",
+                                                "cuda",
+                                                pin_memory=True)
+                    # Re-assign the multimodal_data to the request after strip_for_generation for subsequent generation requests
+                    request.py_multimodal_data = multimodal_params.multimodal_data
                     multimodal_params_list.append(multimodal_params)
 
         extend_requests += extend_dummy_requests
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index db8d84fcc89..2d8c1c0255d 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -486,7 +486,7 @@ def _deduce_max_tokens(request: GenerationRequest,
             lora_config=lora_config,
             prompt_tuning_config=prompt_tuning_config,
             multimodal_input=multimodal_input,
-            #NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
+            # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data and will be handled below by `py_multimodal_data`.
             multimodal_embedding=None,
             mrope_config=None,
             logits_post_processor_name=(
@@ -502,17 +502,8 @@ def _deduce_max_tokens(request: GenerationRequest,
 
         if self._is_pytorch_backend and request.multimodal_params is not None:
             if request.multimodal_params.multimodal_data is not None:
-                # Convert back to tensor, as opposite to `to_handle` in `llm.generate_async`
-                # for values with non-selected keys, it's no-op
-                request.multimodal_params.to_tensor(
-                    "multimodal_data", key="multimodal_embedding")
-                embedding = request.multimodal_params.multimodal_data.get(
-                    "multimodal_embedding")
-                if embedding is not None and embedding.is_cuda:
-                    # make sure the embedding resides on the local device
-                    request.multimodal_params.multimodal_data[
-                        "multimodal_embedding"] = embedding.to("cuda")
-
+                # NOTE: Deserialize SharedTensor handles back to actual tensors
+                request.multimodal_params.to_tensor("multimodal_data")
                 executor_request.py_multimodal_data = request.multimodal_params.multimodal_data
 
         if self._is_pytorch_backend and request.sampling_params.logits_processor:
diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py
index 93689065879..77eec0ff6a6 100644
--- a/tensorrt_llm/inputs/multimodal.py
+++ b/tensorrt_llm/inputs/multimodal.py
@@ -159,6 +159,10 @@ class MultimodalParams:
         multimodal_input: Multimodal input data with hashing information.
multimodal_data: Processed multimodal data containing embeddings, configurations, and modality-specific data organized by type. + multimodal_runtime: Runtime data for tracking multimodal token caching and reuse + during KV cache scenarios. Contains information about cached + tokens, multimodal token positions, and lengths for efficient + processing during inference. Structure of multimodal_data: { @@ -190,237 +194,216 @@ def __post_init__(self): if self.multimodal_data is None: self.multimodal_data = {} - def to_device(self, - element: str, - device: str, - pin_memory: bool = False) -> None: - """Move specified multimodal data element to target device. + def _is_shared_tensor_dict(self, obj: Any) -> bool: + """Check if an object is a shared tensor dictionary. Args: - element: Element to move ("multimodal_data" or "multimodal_input") - device: Target device (e.g., "cuda", "cpu") - pin_memory: Whether to pin memory for faster transfers + obj: Object to check + + Returns: + True if the object is a shared tensor dictionary, False otherwise """ + if not isinstance(obj, dict): + return False - def _to_device( - input_tensor: Union[torch.Tensor, List, dict, None], - pin_memory: bool = False, - ) -> Union[torch.Tensor, List, dict, None]: - if input_tensor is None: - return None - elif isinstance(input_tensor, list): - return [_to_device(item, pin_memory) for item in input_tensor] - elif isinstance(input_tensor, dict): - return { - key: _to_device(value, pin_memory) - for key, value in input_tensor.items() - } - elif isinstance(input_tensor, torch.Tensor): - if pin_memory and input_tensor.device.type == 'cpu': - return input_tensor.pin_memory().to(device, - non_blocking=True) - else: - return input_tensor.to(device, non_blocking=True) - else: - return input_tensor + # Check for required keys that uniquely identify a shared tensor dict + required_keys = {'method_key'} + if not required_keys.issubset(obj.keys()): + return False - if element == "multimodal_data": - self.multimodal_data = _to_device(self.multimodal_data, pin_memory) - elif element == "multimodal_input": - self.multimodal_input = _to_device(self.multimodal_input, - pin_memory) - else: - print( - f"MultimodalParams: Unsupported element '{element}' to move to device. " - f"Supported elements: 'multimodal_data', 'multimodal_input'") + # Additional validation based on method_key + method_key = obj.get('method_key') - def to_handle(self, element: str, key: Optional[str] = None) -> None: - """Convert multimodal data to tensor handle. + # Import here to avoid circular imports + from tensorrt_llm._torch.shared_tensor import \ + _SharedTensorRebuildMethodRegistry - Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries) - for efficient IPC. This function is a in-place operation. + if method_key == _SharedTensorRebuildMethodRegistry.REBUILD_CUDA: + cuda_keys = {'tensor_size', 'storage_handle', 'storage_device'} + return cuda_keys.issubset(obj.keys()) + elif method_key == _SharedTensorRebuildMethodRegistry.REBUILD_CPU: + cpu_keys = {'tensor_size', 'storage_handle', 'manager_handle'} + return cpu_keys.issubset(obj.keys()) - Args: - element: Element to convert ("multimodal_data" or "multimodal_input") - key: Specific key to convert. If None, converts all tensor values in multimodal_data. - Defaults to None. 
+ return False + + def _apply_tensor_operation( + self, input_data: Union[torch.Tensor, List, dict, None], + operation: str, **kwargs) -> Union[torch.Tensor, List, dict, None]: + """Apply tensor operations recursively to nested data structures. - Example: - # Convert all tensors in multimodal_data to handles - params.to_handle("multimodal_data", key=None) + This method handles three types of operations: + - "to_handle": Convert tensors to shared tensor dictionaries + - "to_tensor": Convert shared tensor dictionaries back to tensors + - "to_device": Move tensors to specified device - # Convert only multimodal_embedding section tensors to handles - params.to_handle("multimodal_data", key="multimodal_embedding") + Args: + input_data: Input data structure (tensor, list, dict, or None) + operation: Operation to apply + **kwargs: Additional arguments for the operation + + Returns: + Transformed data structure """ - # Lazy import to avoid circular dependency - from tensorrt_llm._torch.shared_tensor import SharedTensorContainer - - def _to_tensor_handle(data): - for k, v in data.items(): - if isinstance(v, torch.Tensor): - # Convert tensor to handle - handle = SharedTensorContainer.from_tensor(v).dump_to_dict() - data[k] = handle - elif isinstance(v, dict): - _to_tensor_handle(v) - elif isinstance(v, list): - for i, item in enumerate(v): - if isinstance(item, torch.Tensor): - handle = SharedTensorContainer.from_tensor( - item).dump_to_dict() - v[i] = handle - - if element == "multimodal_data": - if self.multimodal_data is None: - return - if key is None: - _to_tensor_handle(self.multimodal_data) + # Handle None case + if input_data is None: + return None + + # Handle list case - recursively process each element + if isinstance(input_data, list): + return [ + self._apply_tensor_operation(item, operation, **kwargs) + for item in input_data + ] + + # Handle dictionary case + if isinstance(input_data, dict): + if operation == "to_tensor" and self._is_shared_tensor_dict( + input_data): + # Convert shared tensor dict back to tensor + try: + # Import here to avoid circular imports + from tensorrt_llm._torch.shared_tensor import \ + SharedTensorContainer + + return SharedTensorContainer.from_dict( + input_data).get_local_view() + except Exception as e: + raise RuntimeError( + f"Failed to restore tensor from shared tensor dict: {e}" + ) else: - if key not in self.multimodal_data: - return # no-op if key not found - - value = self.multimodal_data[key] - if isinstance(value, torch.Tensor): - handle = SharedTensorContainer.from_tensor( - value).dump_to_dict() - self.multimodal_data[key] = handle - elif isinstance(value, dict): - _to_tensor_handle(value) - else: + # Regular dictionary - recursively process values + return { + key: self._apply_tensor_operation(value, operation, + **kwargs) + for key, value in input_data.items() + } + + # Handle tensor case + if isinstance(input_data, torch.Tensor): + if operation == "to_handle": + try: + # Import here to avoid circular imports + from tensorrt_llm._torch.shared_tensor import \ + SharedTensorContainer + return SharedTensorContainer.from_tensor( + input_data).dump_to_dict() + except Exception as e: + raise RuntimeError( + f"Failed to convert tensor to shared tensor: {e}") + elif operation == "to_device": + device = kwargs.get('device') + if device is None: raise ValueError( - f"Unsupported value type for multimodal_data: {type(value)}" - ) - elif element == "multimodal_input": - # No-op for multimodal_input - return - else: - raise ValueError( - f"Unsupported 
element '{element}' to convert to handle.") + "Device must be specified for 'to_device' operation") - def to_tensor(self, element: str, key: Optional[str] = None) -> None: - """Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle. + pin_memory = kwargs.get('pin_memory', False) + try: + if pin_memory and input_data.device.type == 'cpu': + return input_data.pin_memory().to(device, + non_blocking=True) + else: + return input_data.to(device, non_blocking=True) + except Exception as e: + raise RuntimeError( + f"Failed to move tensor to device {device}: {e}") - Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects - for local computation. This function performs in-place modifications to the multimodal_data. + # For any other type, return as-is + return input_data - Args: - element: Element to convert ("multimodal_data" or "multimodal_input") - key: Specific key to convert. If None, converts all tensor handles in multimodal_data. - Defaults to None. + def to_handle(self, element: str) -> None: + """Move specified multimodal data element to shared tensor. - Example: - # Convert all handles back to tensors - params.to_tensor("multimodal_data", key=None) + Args: + element: Element to move (only "multimodal_data" is supported) - # Convert only multimodal_embedding section handles back to tensors - params.to_tensor("multimodal_data", key="multimodal_embedding") + Raises: + ValueError: If element is not "multimodal_data" + RuntimeError: If tensor conversion fails """ - # Lazy import to avoid circular dependency - from tensorrt_llm._torch.shared_tensor import SharedTensorContainer - - def _to_tensor(data): - for k, v in data.items(): - if isinstance(v, dict) and 'method_key' in v: - # This is a tensor handle (dict with method_key) - try: - tensor = SharedTensorContainer.from_dict( - v).get_local_view() - data[k] = tensor - except Exception as e: - raise ValueError( - f"Failed to convert handle to tensor for key '{k}': {e}" - ) - elif isinstance(v, dict): - _to_tensor(v) - elif isinstance(v, list): - for i, item in enumerate(v): - if isinstance(item, dict) and 'method_key' in item: - try: - tensor = SharedTensorContainer.from_dict( - item).get_local_view() - v[i] = tensor - except Exception as e: - raise ValueError( - f"Failed to convert handle to tensor in list at index {i}: {e}" - ) - - if element == "multimodal_data": - if self.multimodal_data is None: - return - - if key is None: - _to_tensor(self.multimodal_data) - else: - if key not in self.multimodal_data: - return # no-op if key not found - - value = self.multimodal_data[key] - if isinstance( - value, dict - ) and 'method_key' in value: # This is a tensor handle - try: - tensor = SharedTensorContainer.from_dict( - value).get_local_view() - self.multimodal_data[key] = tensor - except Exception as e: - raise ValueError( - f"Failed to convert handle to tensor for key '{key}': {e}" - ) - elif isinstance(value, dict): - _to_tensor(value) - else: - raise ValueError( - f"Unsupported value type for multimodal_data: {type(value)}" - ) + if element != "multimodal_data": + raise ValueError( + f"Unsupported element '{element}'. Only 'multimodal_data' is supported." 
+ ) - elif element == "multimodal_input": - # No-op for multimodal_input - return - else: + data = getattr(self, element) + if data is None: + return # Nothing to convert + + transformed_data = self._apply_tensor_operation(data, "to_handle") + setattr(self, element, transformed_data) + + def to_tensor(self, element: str) -> None: + """Move specified multimodal data element from shared tensor. + + Args: + element: Element to restore (only "multimodal_data" is supported) + + Raises: + ValueError: If element is not "multimodal_data" + RuntimeError: If tensor restoration fails + """ + if element != "multimodal_data": raise ValueError( - f"Unsupported element '{element}' to convert to tensor.") + f"Unsupported element '{element}'. Only 'multimodal_data' is supported." + ) + + data = getattr(self, element) + if data is None: + return # Nothing to restore - def strip_for_context(self) -> None: - """Strip multimodal data for context processing. + restored_data = self._apply_tensor_operation(data, "to_tensor") + setattr(self, element, restored_data) - Removes only mrope_position_deltas while keeping all other multimodal data - (embeddings, images, etc.) needed for context phase processing. + def to_device(self, + element: str, + device: str, + pin_memory: bool = False) -> None: + """Move specified multimodal data element to target device. + + Args: + element: Element to move (only "multimodal_data" is supported) + device: Target device (e.g., "cuda", "cpu") + pin_memory: Whether to pin memory for faster transfers + + Raises: + ValueError: If element is not "multimodal_data" or device is invalid + RuntimeError: If device transfer fails """ - if not (self.multimodal_data - and 'mrope_config' in self.multimodal_data): - return + if element != "multimodal_data": + raise ValueError( + f"Unsupported element '{element}'. Only 'multimodal_data' is supported." + ) - mrope_config = self.multimodal_data['mrope_config'] - if 'mrope_position_deltas' in mrope_config: - del mrope_config['mrope_position_deltas'] + data = getattr(self, element) + if data is None: + return # Nothing to move - # Clean up empty mrope_config - if not mrope_config: - del self.multimodal_data['mrope_config'] + transformed_data = self._apply_tensor_operation(data, + "to_device", + device=device, + pin_memory=pin_memory) + setattr(self, element, transformed_data) def strip_for_generation(self) -> None: """Strip multimodal data for generation processing. - Keeps only mrope_position_deltas and removes all other multimodal data + Keeps only mrope_config and removes all other multimodal data (embeddings, images, etc.) as they're not needed during generation. 
""" if not self.multimodal_data: return - # Extract mrope_position_deltas before clearing - mrope_position_deltas = None + # Extract mrope_config before clearing + mrope_config = None if 'mrope_config' in self.multimodal_data: mrope_config = self.multimodal_data['mrope_config'] - if isinstance(mrope_config, - dict) and 'mrope_position_deltas' in mrope_config: - mrope_position_deltas = mrope_config['mrope_position_deltas'] - # Clear all data and restore only position deltas if they exist + # Clear all data and restore only mrope_config if it exists self.multimodal_data = {} - if mrope_position_deltas is not None: - self.multimodal_data['mrope_config'] = { - 'mrope_position_deltas': mrope_position_deltas - } + if mrope_config is not None: + self.multimodal_data['mrope_config'] = mrope_config def has_content(self) -> bool: """Check if this object contains any multimodal data.""" diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 12bb079eaf5..78eb0304b5f 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -403,13 +403,12 @@ def generate_async( 'multimodal_input'), multimodal_data=extra_processed_inputs.get( 'multimodal_data')) - # Convert to shared tensor handle to reduce IPC overhead - # for values with non-selected keys, it's no-op - multimodal_params.to_handle("multimodal_data", - key="multimodal_embedding") # Only pass it if it has content if not multimodal_params.has_content(): multimodal_params = None + else: + # Convert to shared tensor handle to reduce IPC overhead + multimodal_params.to_handle("multimodal_data") else: raise TypeError( f"The inputs must be type str or list of int, but got {type(inputs)}" diff --git a/tests/unittest/_torch/multimodal/test_share_multiparams.py b/tests/unittest/_torch/multimodal/test_share_multiparams.py index d4ce40f6332..343c2f53721 100644 --- a/tests/unittest/_torch/multimodal/test_share_multiparams.py +++ b/tests/unittest/_torch/multimodal/test_share_multiparams.py @@ -39,14 +39,20 @@ def test_to_handle_none_multimodal_data(self): params.to_handle("multimodal_data") self.assertEqual(params.multimodal_data, {}) + def test_to_handle_unsupported_element(self): + """Test to_handle raises ValueError for unsupported elements.""" params = MultimodalParams() multimodal_input = MultimodalInput( multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]] * 2, multimodal_positions=[0, 10], multimodal_lengths=[2, 2]) params.multimodal_input = multimodal_input - params.to_handle("multimodal_input") - self.assertEqual(params.multimodal_input, multimodal_input) + + with self.assertRaises(ValueError) as context: + params.to_handle("multimodal_input") + + self.assertIn("Unsupported element 'multimodal_input'", + str(context.exception)) def test_to_tensor_basic_handle(self): """Test converting a basic handle back to tensor.""" @@ -54,9 +60,9 @@ def test_to_tensor_basic_handle(self): params.multimodal_data = {"multimodal_embedding": self.mm_embedding} # Convert to handle - params.to_handle("multimodal_data", key="multimodal_embedding") + params.to_handle("multimodal_data") # Convert back to tensor - params.to_tensor("multimodal_data", key="multimodal_embedding") + params.to_tensor("multimodal_data") result = params.multimodal_data["multimodal_embedding"] self.assertIsInstance(result, torch.Tensor) @@ -67,8 +73,8 @@ def test_to_tensor_all_handles(self): params = MultimodalParams() params.multimodal_data = self.sample_multimodal_data.copy() - params.to_handle("multimodal_data", key=None) - params.to_tensor("multimodal_data", key=None) + 
params.to_handle("multimodal_data")
+        params.to_tensor("multimodal_data")
 
         self.assertTrue(
             torch.allclose(params.multimodal_data["multimodal_embedding"],
@@ -90,5 +96,56 @@ def test_to_tensor_all_handles(self):
                         self.image["image_width"])
 
 
+class TestMultimodalParamsDeviceTransfer(unittest.TestCase):
+    """Test cases for the to_device method in MultimodalParams."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.mm_embedding = torch.randn(3, 4, 5)
+        self.mrope_config = {
+            "mrope_rotary_cos_sin": torch.randn(2, 3),
+            "mrope_position_deltas": torch.randn(5),
+        }
+        self.image = {
+            "pixel_values": torch.randn(1, 3, 224, 224),
+            "image_height": [224],
+            "image_width": [224],
+        }
+        self.sample_multimodal_data = {
+            "multimodal_embedding": self.mm_embedding,
+            "mrope_config": self.mrope_config,
+            "image": self.image,
+        }
+
+    def test_to_device_basic(self):
+        """Test moving a basic tensor to the device."""
+        params = MultimodalParams()
+        params.multimodal_data = {"multimodal_embedding": self.mm_embedding}
+
+        params.to_device("multimodal_data", device="cuda:0", pin_memory=True)
+
+        result = params.multimodal_data["multimodal_embedding"]
+        self.assertEqual(result.device, torch.device("cuda:0"))
+
+    def test_to_device_all_data(self):
+        """Test moving all multimodal data to the device."""
+        params = MultimodalParams()
+        params.multimodal_data = self.sample_multimodal_data.copy()
+
+        params.to_device("multimodal_data", device="cuda:0", pin_memory=True)
+
+        result = params.multimodal_data["multimodal_embedding"]
+        self.assertEqual(result.device, torch.device("cuda:0"))
+
+        result = params.multimodal_data["mrope_config"]["mrope_rotary_cos_sin"]
+        self.assertEqual(result.device, torch.device("cuda:0"))
+
+        result = params.multimodal_data["mrope_config"]["mrope_position_deltas"]
+        self.assertEqual(result.device, torch.device("cuda:0"))
+
+        result = params.multimodal_data["image"]["pixel_values"]
+        self.assertEqual(result.device, torch.device("cuda:0"))
+
+
 if __name__ == "__main__":
     unittest.main()
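
For reference, below is a minimal usage sketch of the reworked MultimodalParams flow this change introduces (serialize to shared-tensor handles before IPC, deserialize on the worker, then stage tensors on the GPU). The import path and the standalone calling pattern are assumptions for illustration only; in the actual code paths these calls are issued by llm.generate_async, the executor worker, and model_engine._prepare_tp_inputs rather than by user code, and a CUDA-capable build is assumed.

import torch

# Assumed import path, matching tensorrt_llm/inputs/multimodal.py in this diff.
from tensorrt_llm.inputs.multimodal import MultimodalParams

# Producer side (cf. llm.generate_async): wrap processed inputs and, if there is
# any content, convert every tensor into a serializable shared-tensor handle.
params = MultimodalParams(multimodal_data={
    "multimodal_embedding": torch.randn(3, 4, 5),
    "mrope_config": {"mrope_position_deltas": torch.zeros(1, dtype=torch.int32)},
})
if params.has_content():
    params.to_handle("multimodal_data")  # tensors -> handle dicts for IPC

# Consumer side (cf. executor/worker.py): rebuild local tensor views from handles.
params.to_tensor("multimodal_data")      # handle dicts -> tensors

# Model engine, context phase (cf. _prepare_tp_inputs): stage data on the GPU,
# pinning CPU tensors first for faster, non-blocking transfers.
if params.has_content():
    params.to_device("multimodal_data", "cuda", pin_memory=True)

# Model engine, generation phase: keep only mrope_config, then move what remains.
params.strip_for_generation()
if params.has_content():
    params.to_device("multimodal_data", "cuda", pin_memory=True)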