1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -693,6 +693,7 @@ def compile(

    # Move the weights in the state_dict to CPU
    if offload_module_to_cpu:
        deallocate_module(gm, delete_module=False)
        deallocate_module(exported_program.module(), delete_module=False)
        logger.info(
            "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
26 changes: 9 additions & 17 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):


class TRTInterpreterResult(NamedTuple):
    serialized_engine: bytes
    engine: trt.ICudaEngine
    input_names: Sequence[str]
    output_names: Sequence[str]
    weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +512,7 @@ def _save_weight_mapping(self) -> None:
_LOGGER.info("Building weight name mapping...")
# Stage 1: Name mapping
torch_device = to_torch_device(self.compilation_settings.device)
self.module.to(torch_device)
sd = self.module.state_dict()
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
weight_name_map: dict[str, Any] = {}
weight_refit_map = self.ctx.weight_refit_map
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +591,11 @@ def _save_weight_mapping(self) -> None:
        torch.cuda.empty_cache()

    @needs_refit  # type: ignore[misc]
    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
    def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None:
        serialized_engine = engine.serialize()
Collaborator: If we are doing this, don't we end up paying the serialization cost again?

Collaborator (Author): engine.serialize() does not take any memory.

    with io.BytesIO() as engine_bytes:
        engine_bytes.write(serialized_engine)
        engine_str = engine_bytes.getvalue()

This does.

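To make the distinction concrete, here is a minimal, hypothetical sketch (not part of the PR; `engine` stands for an already built trt.ICudaEngine, and the function names are illustrative) of the two paths being compared:

import io

def hold_engine(engine):
    # Path this PR takes: keep the ICudaEngine object itself and defer
    # serialization until a serialized plan is actually needed (e.g. when
    # inserting the engine into the engine cache).
    return engine

def copy_to_bytes(engine):
    # Previous path: engine.serialize() hands back an IHostMemory buffer;
    # writing it into a BytesIO and calling getvalue() materializes a
    # second, independent bytes copy of the whole plan in host memory.
    plan = engine.serialize()
    with io.BytesIO() as engine_bytes:
        engine_bytes.write(plan)
        return engine_bytes.getvalue()
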
        # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
        # if not self.compilation_settings.strip_engine_weights:
        #     # set EXCLUDE_WEIGHTS flag to strip weights
        #     runtime = trt.Runtime(TRT_LOGGER)
        #     engine = runtime.deserialize_cuda_engine(serialized_engine)

        #     serialization_config = engine.create_serialization_config()
        #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
        #     serialized_engine = engine.serialize_with_config(
@@ -750,16 +747,15 @@ def run(
        self._create_timing_cache(
            builder_config, self.compilation_settings.timing_cache_path
        )
        serialized_engine = self.builder.build_serialized_network(

        cuda_engine = self.builder.build_engine_with_config(
            self.ctx.net, builder_config
        )
        assert serialized_engine
        assert cuda_engine

        _LOGGER.info(
            f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
        )
        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

        self.ctx.clear_cpu_weights_reference_holder()

        self._save_timing_cache(
@@ -772,14 +768,10 @@
            and self.compilation_settings.cache_built_engines
            and self.engine_cache is not None
        ):
            self._insert_engine_to_cache(hash_val, serialized_engine)

        with io.BytesIO() as engine_bytes:
            engine_bytes.write(serialized_engine)
            engine_str = engine_bytes.getvalue()
            self._insert_engine_to_cache(hash_val, cuda_engine)

        return TRTInterpreterResult(
            engine_str,
            cuda_engine,
            self._input_names,
            self._output_names,
            self.weight_name_map,
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -104,7 +104,7 @@ def convert_module(
    )

    return rt_cls(
        serialized_engine=interpreter_result.serialized_engine,
        cuda_engine=interpreter_result.engine,
        input_binding_names=list(interpreter_result.input_names),
        output_binding_names=list(interpreter_result.output_names),
        name=name,
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
    # For TRT INetwork construction the constants are moved to CPU in get_attr call.
    for node, constant in cf.node_replacements.items():
        replace_node_with_constant(
            gm, node, torch.nn.Parameter(constant, requires_grad=False)
            gm,
            node,
            torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
        )

    erased_params = []
35 changes: 28 additions & 7 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -123,6 +123,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]

    def __init__(
        self,
        cuda_engine: trt.ICudaEngine = None,
        serialized_engine: Optional[bytes] = None,
        input_binding_names: Optional[List[str]] = None,
        output_binding_names: Optional[List[str]] = None,
@@ -182,7 +183,19 @@ def __init__(
        # Unused currently - to be used by Dynamic Shape support implementation
        self.memory_pool = None

        self.serialized_engine = serialized_engine
        if cuda_engine:
            assert isinstance(
                cuda_engine, trt.ICudaEngine
            ), "Cuda engine must be a trt.ICudaEngine object"
            self.engine = cuda_engine
        elif serialized_engine:
            assert isinstance(
                serialized_engine, bytes
            ), "Serialized engine must be a bytes object"
            self.engine = serialized_engine
        else:
            raise ValueError("Serialized engine or cuda engine must be provided")

        self.input_names = (
            input_binding_names if input_binding_names is not None else []
        )
@@ -204,7 +217,6 @@ def __init__(
            else False
        )
        self.settings = settings
        self.engine = None
        self.weight_name_map = weight_name_map
        self.target_platform = Platform.current_platform()
        self.runtime_states = TorchTRTRuntimeStates(
@@ -219,7 +231,7 @@ def __init__(
        self.output_allocator: Optional[DynamicOutputAllocator] = None
        self.use_output_allocator_outputs = False

        if self.serialized_engine is not None and not self.settings.lazy_engine_init:
        if self.engine and not self.settings.lazy_engine_init:
            self.setup_engine()

    def get_streamable_device_memory_budget(self) -> Any:
@@ -260,13 +272,22 @@ def set_default_device_memory_budget(self) -> int:
        return self._set_device_memory_budget(budget_bytes)

    def setup_engine(self) -> None:

        if isinstance(self.engine, trt.ICudaEngine):
            pass
        elif isinstance(self.engine, bytes):
            runtime = trt.Runtime(TRT_LOGGER)
            self.engine = runtime.deserialize_cuda_engine(self.engine)
        else:
            raise ValueError(
                "Expected engine as trt.ICudaEngine or serialized engine as bytes"
            )

        assert (
            self.target_platform == Platform.current_platform()
        ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"

        self.initialized = True
        runtime = trt.Runtime(TRT_LOGGER)
        self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
        if self.settings.enable_weight_streaming:
            self.set_default_device_memory_budget()
        self.context = self.engine.create_execution_context()
@@ -302,7 +323,7 @@ def _check_initialized(self) -> None:
            raise RuntimeError("PythonTorchTensorRTModule is not initialized.")

    def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
        state_dict[prefix + "engine"] = self.serialized_engine
        state_dict[prefix + "engine"] = self.engine
        state_dict[prefix + "input_names"] = self.input_names
        state_dict[prefix + "output_names"] = self.output_names
        state_dict[prefix + "platform"] = self.target_platform
@@ -317,7 +338,7 @@ def _load_from_state_dict(
        unexpected_keys: Any,
        error_msgs: Any,
    ) -> None:
        self.serialized_engine = state_dict[prefix + "engine"]
        self.engine = state_dict[prefix + "engine"]
        self.input_names = state_dict[prefix + "input_names"]
        self.output_names = state_dict[prefix + "output_names"]
        self.target_platform = state_dict[prefix + "platform"]
24 changes: 20 additions & 4 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -2,10 +2,12 @@

import base64
import copy
import io
import logging
import pickle
from typing import Any, List, Optional, Tuple, Union

import tensorrt as trt
import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import Platform
@@ -76,6 +78,7 @@ class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc]

    def __init__(
        self,
        cuda_engine: Optional[trt.ICudaEngine | bytes] = None,
        serialized_engine: Optional[bytes] = None,
        input_binding_names: Optional[List[str]] = None,
        output_binding_names: Optional[List[str]] = None,
@@ -123,8 +126,22 @@ def __init__(
"""
super(TorchTensorRTModule, self).__init__()

if not isinstance(serialized_engine, bytearray):
ValueError("Expected serialized engine as bytearray")
if serialized_engine:
assert isinstance(
serialized_engine, bytes
), "Serialized engine must be a bytes object"
self.serialized_engine = serialized_engine

elif cuda_engine:
assert isinstance(
cuda_engine, trt.ICudaEngine
), "Cuda engine must be a trt.ICudaEngine object"
serialized_engine = cuda_engine.serialize()
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine) # type: ignore
self.serialized_engine = engine_bytes.getvalue()
else:
raise ValueError("Serialized engine or cuda engine must be provided")

self.input_binding_names = (
input_binding_names if input_binding_names is not None else []
@@ -136,12 +153,11 @@ def __init__(
        self.hardware_compatible = settings.hardware_compatible
        self.settings = copy.deepcopy(settings)
        self.weight_name_map = weight_name_map
        self.serialized_engine = serialized_engine
        self.engine = None
        self.requires_output_allocator = requires_output_allocator

        if (
            serialized_engine
            self.serialized_engine
            and not self.settings.lazy_engine_init
            and not self.settings.enable_cross_compile_for_windows
        ):
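
For illustration, a hedged usage sketch of the two construction paths the updated runtime modules now accept (the engine objects, binding names, and import path are assumptions for this example, not part of the PR):

import tensorrt as trt
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

# Assumes `cuda_engine` is a live trt.ICudaEngine from a fresh build and
# `plan_bytes` is a serialized plan read back from disk.
module_from_engine = PythonTorchTensorRTModule(
    cuda_engine=cuda_engine,
    input_binding_names=["input_0"],
    output_binding_names=["output_0"],
)
module_from_bytes = PythonTorchTensorRTModule(
    serialized_engine=plan_bytes,
    input_binding_names=["input_0"],
    output_binding_names=["output_0"],
)
# Passing neither raises ValueError("Serialized engine or cuda engine must be provided").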