diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 8e18a3ae32..2bdfe6fb6b 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -169,14 +169,13 @@ def __init__(
         multi_gpu_device_check()
 
         self.name = name
-        self._input_buffers: List[torch.Tensor] = []
-        self._output_buffers: List[torch.Tensor] = []
-        self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
+        self._input_buffers: Dict[str, List[torch.Tensor]] = {}
+        self._output_buffers: Dict[str, List[torch.Tensor]] = {}
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
 
         # TODO: Make the below a Dictionary {shape: cudagraph}
-        self.shape_key: Optional[str] = None
+        self.shape_key_to_cudagraph: Dict[str, torch.cuda.CUDAGraph] = {}
 
         # See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98
         # Unused currently - to be used by Dynamic Shape support implementation
@@ -293,9 +292,6 @@ def setup_engine(self) -> None:
         if self.requires_output_allocator:
             self.create_output_allocator()
 
-        if torch_tensorrt.runtime.get_cudagraphs_mode():
-            self.cudagraph = torch.cuda.CUDAGraph()
-
     def _check_initialized(self) -> None:
         if not self.initialized:
             raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
@@ -342,10 +338,13 @@ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
         result.__setstate__(self.__getstate__())
         return result
 
-    def _reset_captured_graph(self) -> None:
-        if self.cudagraph:
-            self.cudagraph.reset()
-            self.cudagraph = None
+    def _reset_captured_graph(self, inputs_shape_key: str | None = None) -> None:
+        # With no key (e.g. when called from __del__), reset every captured graph
+        keys = [inputs_shape_key] if inputs_shape_key else list(self.shape_key_to_cudagraph)
+        for key in keys:
+            if key in self.shape_key_to_cudagraph:
+                self.shape_key_to_cudagraph[key].reset()
+                self.shape_key_to_cudagraph.pop(key)
 
     def __del__(self) -> None:
         self._reset_captured_graph()
@@ -355,7 +354,8 @@ def setup_input_tensors(
         self,
         contiguous_inputs: List[torch.Tensor],
         cudagraphs_enabled: bool,
         need_cudagraphs_record: bool,
+        inputs_shape_key: str | None = None,
     ) -> None:
         for i, input_name in enumerate(self.input_names):
             if not contiguous_inputs[i].is_cuda:
@@ -374,14 +374,22 @@ def setup_input_tensors(
                 contiguous_inputs[i].dtype == self.input_dtypes[i]
             ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
+            is_shape_tensor_input = self.engine.is_shape_inference_io(input_name)
             if need_cudagraphs_record:
                 # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
                 # Clone is required to avoid re-using user-provided GPU memory
-                self._input_buffers[i] = contiguous_inputs[i].clone()
+                if is_shape_tensor_input:
+                    self._input_buffers[inputs_shape_key][i] = (
+                        contiguous_inputs[i].cpu().clone()
+                    )
+                else:
+                    self._input_buffers[inputs_shape_key][i] = contiguous_inputs[
+                        i
+                    ].clone()
 
             # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
             # as per TensorRT requirements
-            if self.engine.is_shape_inference_io(input_name):
+            if is_shape_tensor_input:
                 # Shape tensor inputs are casted to int64 explicitly
                 # Currently Torch CPU pointers are not working; numpy pointers are used instead
                 # to refer to underlying memory
@@ -392,9 +400,9 @@ def setup_input_tensors(
                     input_name, tuple(contiguous_inputs[i].shape)
                 )
                 if cudagraphs_enabled:
-                    self._input_buffers[i].copy_(contiguous_inputs[i])
+                    self._input_buffers[inputs_shape_key][i].copy_(contiguous_inputs[i])
                     self.context.set_tensor_address(
-                        input_name, self._input_buffers[i].data_ptr()
+                        input_name, self._input_buffers[inputs_shape_key][i].data_ptr()
                     )
                 else:
                     self.context.set_tensor_address(
@@ -430,7 +438,7 @@ def create_output_allocator(self) -> None:
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
-            shape_changed = self.validate_input_shapes(inputs)
+            shape_changed, inputs_shape_key = self.validate_input_shapes(inputs)
             (
                 need_cudagraphs_record,
                 can_use_pre_allocated_outputs,
@@ -440,11 +448,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             )
 
             if need_cudagraphs_reset:
-                self._reset_captured_graph()
+                self._reset_captured_graph(inputs_shape_key)
 
             if need_cudagraphs_record:
-                self._input_buffers = [None] * len(self.input_names)
-                self._output_buffers = [None] * len(self.output_names)
+                self._input_buffers[inputs_shape_key] = [None] * len(self.input_names)
+                self._output_buffers[inputs_shape_key] = [None] * len(self.output_names)
 
             with (
                 torch.autograd.profiler.record_function(
@@ -458,7 +466,10 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
 
                 self.setup_input_tensors(
-                    contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
+                    contiguous_inputs,
+                    self.cudagraphs_enabled,
+                    need_cudagraphs_record,
+                    inputs_shape_key,
                 )
 
                 if shape_changed:
@@ -492,11 +503,12 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
                 for o, output_name in enumerate(self.output_names):
                     if need_cudagraphs_record:
-                        self._output_buffers[o] = outputs[o].clone()
+                        self._output_buffers[inputs_shape_key][o] = outputs[o].clone()
 
                     if self.cudagraphs_enabled:
                         self.context.set_tensor_address(
-                            output_name, self._output_buffers[o].data_ptr()
+                            output_name,
+                            self._output_buffers[inputs_shape_key][o].data_ptr(),
                         )
                     else:
                         self.context.set_tensor_address(
@@ -522,24 +534,31 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 with torch.cuda.stream(self._engine_stream):
                     if self.cudagraphs_enabled:
                         if need_cudagraphs_record:
-                            self.cudagraph = torch.cuda.CUDAGraph()
+                            self.shape_key_to_cudagraph[inputs_shape_key] = (
+                                torch.cuda.CUDAGraph()
+                            )
 
                             if self.profiling_enabled:
-                                self.cudagraph.enable_debug_mode()
+                                self.shape_key_to_cudagraph[
+                                    inputs_shape_key
+                                ].enable_debug_mode()
 
                             with torch.cuda.graph(
-                                self.cudagraph, stream=self._engine_stream
+                                self.shape_key_to_cudagraph[inputs_shape_key],
+                                stream=self._engine_stream,
                             ):
                                 self.context.execute_async_v3(
                                     self._engine_stream.cuda_stream
                                 )
 
                             if self.profiling_enabled:
-                                self.cudagraph.debug_dump(
+                                self.shape_key_to_cudagraph[
+                                    inputs_shape_key
+                                ].debug_dump(
                                     f"{DEBUG_LOGGING_DIR}/{self.name}_cudagraph.dot"
                                 )
 
-                        self.cudagraph.replay()  # type: ignore
+                        self.shape_key_to_cudagraph[inputs_shape_key].replay()  # type: ignore
 
                     else:
                         self.context.execute_async_v3(self._engine_stream.cuda_stream)
@@ -551,7 +570,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             if self.cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
-                    o.copy_(self._output_buffers[idx])
+                    o.copy_(self._output_buffers[inputs_shape_key][idx])
 
             if len(outputs) == 1:
                 return outputs[0]
@@ -742,27 +761,26 @@ def get_layer_info(self) -> str:
         )
         return engine_json
 
-    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def validate_input_shapes(
+        self, inputs: Sequence[torch.Tensor | Any]
+    ) -> Tuple[bool, str]:
         """
         Validates the input shapes of the forward function has changed
         """
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
        # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
-        tensor_inputs = []
-        for t in inputs:
-            if not isinstance(t, torch.Tensor):
-                return True
-            tensor_inputs.append(t)
         new_shape_key = "".join(
-            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+            str(
+                tuple(t.shape if hasattr(t, "shape") else torch.tensor(t).shape)
+            ).replace(" ", "")
+            for t in inputs
         )
 
-        # If the new shape key differs from the existing one,
-        # invalidate the old shape key and remove the CUDAGraph
-        if new_shape_key != self.shape_key:
-            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
-            self.shape_key = new_shape_key
-            return True
+        if new_shape_key not in self.shape_key_to_cudagraph:
+            logger.debug(
+                f"The user provided input shape {new_shape_key} is not found in recorded CUDAGraph input shapes. A new CUDAGraph will be recorded with this input shape."
+            )
+            return True, new_shape_key
 
-        return False
+        return False, new_shape_key
diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py
index 7e50b515c2..fc3f22843e 100644
--- a/tools/llm/run_llm.py
+++ b/tools/llm/run_llm.py
@@ -175,7 +175,7 @@ def measure_perf(trt_model, input_signature, backend_name):
     arg_parser.add_argument(
         "--model",
         type=str,
-        default="meta-llama/Llama-3.2-1B-Instruct",
+        default="Qwen/Qwen2.5-0.5B-Instruct",
         help="Name of LLM model",
     )
     arg_parser.add_argument(
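
Note: the runtime change above amounts to caching one captured CUDA graph per input-shape key, instead of holding a single graph that is torn down whenever the input shapes change. The sketch below is a minimal, self-contained illustration of that record-on-miss / replay-on-hit pattern, assuming a plain callable in place of the TensorRT engine. `ShapeKeyedGraphRunner` and its attribute names are hypothetical, not Torch-TensorRT API; the shape-key scheme mirrors `validate_input_shapes` above.

```python
from typing import Callable, Dict

import torch


class ShapeKeyedGraphRunner:
    """Illustrative only: one captured CUDA graph per input-shape key."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> None:
        self.fn = fn
        self.graphs: Dict[str, torch.cuda.CUDAGraph] = {}
        self.static_inputs: Dict[str, torch.Tensor] = {}
        self.static_outputs: Dict[str, torch.Tensor] = {}

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Same key scheme as validate_input_shapes: "(2,3)" for a 2x3 input
        key = str(tuple(x.shape)).replace(" ", "")
        if key not in self.graphs:
            # Record path: reserve a static input buffer for this shape
            self.static_inputs[key] = x.clone()
            # Warm up on a side stream before capture, per the PyTorch docs
            side = torch.cuda.Stream()
            side.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side):
                self.fn(self.static_inputs[key])
            torch.cuda.current_stream().wait_stream(side)
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                self.static_outputs[key] = self.fn(self.static_inputs[key])
            self.graphs[key] = graph
        # Replay path: stage the new data into the recorded input buffer
        self.static_inputs[key].copy_(x)
        self.graphs[key].replay()
        # Clone so a later replay for this key cannot overwrite the result
        return self.static_outputs[key].clone()


if __name__ == "__main__":
    runner = ShapeKeyedGraphRunner(lambda t: t * 2 + 1)
    a = runner(torch.ones(2, 3, device="cuda"))  # records a graph for (2,3)
    b = runner(torch.ones(4, 5, device="cuda"))  # records a second graph for (4,5)
    c = runner(torch.full((2, 3), 3.0, device="cuda"))  # replays the (2,3) graph
    print(c)  # tensor filled with 7.0
```

This is the same flow `run_standard_execution` now implements against `shape_key_to_cudagraph` and the per-key `_input_buffers`/`_output_buffers`: `validate_input_shapes` returns the key, a miss triggers a new capture, and a hit replays the stored graph.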