Skip to content

Commit 50678a5

Browse files
Some fixes for the dynamic memory setting (#3729)
Co-authored-by: Adrian Wang <[email protected]>
1 parent e22be64 commit 50678a5

File tree

11 files changed

+83
-72
lines changed

11 files changed

+83
-72
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ TRTEngine::TRTEngine(
6262
bool hardware_compatible,
6363
bool requires_output_allocator,
6464
const std::string& serialized_metadata,
65-
const ResourceAllocationStrategy& resource_allocation_strategy)
65+
const ResourceAllocationStrategy resource_allocation_strategy)
6666
: TRTEngine(
6767
"deserialized_trt",
6868
serialized_engine,
@@ -86,7 +86,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
8686
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
8787
static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
8888
serialized_info[SERIALIZED_METADATA_IDX],
89-
resource_allocation_strategy_from_string(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) {}
89+
(static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic)) {}
9090

9191
TRTEngine::TRTEngine(
9292
const std::string& mod_name,
@@ -98,7 +98,7 @@ TRTEngine::TRTEngine(
9898
bool hardware_compatible,
9999
bool requires_output_allocator,
100100
const std::string& serialized_metadata,
101-
const ResourceAllocationStrategy& resource_allocation_strategy) {
101+
const ResourceAllocationStrategy resource_allocation_strategy) {
102102
TORCHTRT_CHECK(
103103
is_supported_on_current_platform(target_platform),
104104
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -128,9 +128,11 @@ TRTEngine::TRTEngine(
128128
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
129129
}
130130

131+
this->resource_allocation_strategy = resource_allocation_strategy;
132+
LOG_DEBUG("Resource allocation strategy: " << (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static"));
131133
if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
132134
this->exec_ctx =
133-
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
135+
make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
134136
} else {
135137
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
136138
}
@@ -402,6 +404,7 @@ std::string TRTEngine::to_str() const {
402404
ss << " Device: " << device_info << std::endl;
403405
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
404406
ss << " Target Platform: " << target_platform << std::endl;
407+
ss << " Resource Allocation Strategy: " << (resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "Dynamic" : "Static") << std::endl;
405408
// clang-format on
406409
return ss.str();
407410
}
@@ -469,8 +472,7 @@ std::vector<std::string> TRTEngine::serialize() {
469472
serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
470473
serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
471474
serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
472-
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
473-
resource_allocation_strategy_to_string(this->resource_allocation_strategy);
475+
serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic ? "1" : "0";
474476

475477
return serialized_info;
476478
}
@@ -483,11 +485,12 @@ void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationSt
483485
if (new_strategy != this->resource_allocation_strategy) {
484486
this->resource_allocation_strategy = new_strategy;
485487
if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
486-
std::cout << "Setting resource allocation strategy to dynamic" << std::endl;
487-
this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
488+
LOG_DEBUG("Setting resource allocation strategy to dynamic");
489+
this->exec_ctx = make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
488490
} else {
491+
LOG_DEBUG("Setting resource allocation strategy to static");
489492
this->exec_ctx = make_trt(
490-
cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
493+
cuda_engine->createExecutionContext());
491494
}
492495
}
493496
}

core/runtime/TRTEngine.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
100100

101101
struct TRTEngine : torch::CustomClassHolder {
102102
// Resource Allocation Strategy
103-
enum ResourceAllocationStrategy { kStatic, kDynamic };
103+
typedef enum { kStatic = 0, kDynamic } ResourceAllocationStrategy;
104104
// Each engine needs its own runtime object
105105
std::shared_ptr<nvinfer1::IRuntime> rt;
106106
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
@@ -132,7 +132,7 @@ struct TRTEngine : torch::CustomClassHolder {
132132
bool hardware_compatible = false,
133133
bool requires_output_allocator = false,
134134
const std::string& serialized_metadata = "",
135-
const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
135+
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
136136
TRTEngine::ResourceAllocationStrategy::kStatic);
137137

138138
TRTEngine(std::vector<std::string> serialized_info);
@@ -147,7 +147,7 @@ struct TRTEngine : torch::CustomClassHolder {
147147
bool hardware_compatible = false,
148148
bool requires_output_allocator = false,
149149
const std::string& serialized_metadata = "",
150-
const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
150+
const TRTEngine::ResourceAllocationStrategy resource_allocation_strategy =
151151
TRTEngine::ResourceAllocationStrategy::kStatic);
152152

153153
TRTEngine& operator=(const TRTEngine& other);

core/runtime/register_jit_hooks.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,6 @@ std::string serialize_bindings(const std::vector<std::string>& bindings) {
2222
return serialized_binding_info;
2323
}
2424

25-
std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy) {
26-
if (strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
27-
return std::string("kDynamic");
28-
} else {
29-
return std::string("kStatic");
30-
}
31-
}
32-
33-
TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str) {
34-
if (str == "kDynamic")
35-
return TRTEngine::ResourceAllocationStrategy::kDynamic;
36-
else
37-
return TRTEngine::ResourceAllocationStrategy::kStatic;
38-
}
39-
4025
static const std::string sym_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //=
4126
std::string base64_encode(const std::string& in) {
4227
std::string out;
@@ -106,7 +91,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
10691
.def("infer_outputs", &TRTEngine::infer_outputs)
10792
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
10893
.def(
109-
"_use_dynamically_allocated_resources",
94+
"use_dynamically_allocated_resources",
11095
[](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
11196
self->set_resource_allocation_strategy(
11297
dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic
@@ -124,6 +109,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
124109
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> { return self->serialize(); },
125110
[](std::vector<std::string> serialized_info) -> c10::intrusive_ptr<TRTEngine> {
126111
serialized_info[ENGINE_IDX] = base64_decode(serialized_info[ENGINE_IDX]);
112+
LOG_DEBUG("Deserialized resource allocation strategy: " << (static_cast<bool>(std::stoi(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) ? "Dynamic" : "Static"));
127113
TRTEngine::verify_serialization_fmt(serialized_info);
128114
return c10::make_intrusive<TRTEngine>(serialized_info);
129115
});

examples/dynamo/dynamic_memory_allocation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torch
44
import torch_tensorrt as torch_trt
55
import torchvision.models as models
6-
from diffusers import DiffusionPipeline
6+
import time
7+
import gc
78

89
np.random.seed(5)
910
torch.manual_seed(5)
@@ -14,23 +15,28 @@
1415
"use_python_runtime": False,
1516
"enabled_precisions": {torch.float32},
1617
"immutable_weights": False,
18+
"lazy_engine_init": True,
19+
"dynamically_allocate_resources": True
20+
1721
}
1822

1923
model = models.resnet152(pretrained=True).eval().to("cuda")
2024
compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
2125
print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)
2226
compiled_module(*inputs)
2327

24-
breakpoint()
25-
with torch_trt.dynamo.runtime.ResourceAllocatorContext(compiled_module):
28+
time.sleep(30)
29+
with torch_trt.dynamo.runtime.ResourceAllocationStrategy(compiled_module, dynamically_allocate_resources=False):
2630
print(
2731
"Memory used (GB):",
2832
(torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
2933
)
30-
breakpoint()
3134
compiled_module(*inputs)
35+
gc.collect()
36+
torch.cuda.empty_cache()
37+
time.sleep(30)
3238
print(
3339
"Memory used (GB):",
3440
(torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
3541
)
36-
breakpoint()
42+
compiled_module(*inputs)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def cross_compile_for_windows(
103103
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
104104
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
105105
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
106+
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
106107
**kwargs: Any,
107108
) -> torch.fx.GraphModule:
108109
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -177,6 +178,7 @@ def cross_compile_for_windows(
177178
enable_weight_streaming (bool): Enable weight streaming.
178179
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
179180
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
181+
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
180182
**kwargs: Any,
181183
Returns:
182184
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -340,6 +342,7 @@ def cross_compile_for_windows(
340342
"enable_weight_streaming": enable_weight_streaming,
341343
"tiling_optimization_level": tiling_optimization_level,
342344
"l2_limit_for_tiling": l2_limit_for_tiling,
345+
"dynamically_allocate_resources": dynamically_allocate_resources,
343346
}
344347

345348
# disable the following settings is not supported for cross compilation for windows feature
@@ -440,6 +443,7 @@ def compile(
440443
tiling_optimization_level: str = _defaults.TILING_OPTIMIZATION_LEVEL,
441444
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
442445
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
446+
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
443447
**kwargs: Any,
444448
) -> torch.fx.GraphModule:
445449
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -517,6 +521,7 @@ def compile(
517521
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
518522
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
519523
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
524+
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
520525
**kwargs: Any,
521526
Returns:
522527
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -690,6 +695,7 @@ def compile(
690695
"tiling_optimization_level": tiling_optimization_level,
691696
"l2_limit_for_tiling": l2_limit_for_tiling,
692697
"offload_module_to_cpu": offload_module_to_cpu,
698+
"dynamically_allocate_resources": dynamically_allocate_resources,
693699
}
694700

695701
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
L2_LIMIT_FOR_TILING = -1
5858
USE_DISTRIBUTED_MODE_TRACE = False
5959
OFFLOAD_MODULE_TO_CPU = False
60+
DYNAMICALLY_ALLOCATE_RESOURCES = False
6061

6162
if platform.system() == "Linux":
6263
import pwd

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
DLA_GLOBAL_DRAM_SIZE,
1212
DLA_LOCAL_DRAM_SIZE,
1313
DLA_SRAM_SIZE,
14+
DYNAMICALLY_ALLOCATE_RESOURCES,
1415
DRYRUN,
1516
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
1617
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
@@ -97,6 +98,8 @@ class CompilationSettings:
9798
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
9899
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
99100
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
101+
offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation
102+
dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines
100103
"""
101104

102105
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -140,6 +143,7 @@ class CompilationSettings:
140143
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
141144
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
142145
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
146+
dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES
143147

144148
def __getstate__(self) -> dict[str, Any]:
145149
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (

py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import torch
44

55

6-
class ResourceAllocatorContext(torch.nn.Module): # type: ignore[misc]
6+
class ResourceAllocationStrategy(torch.nn.Module): # type: ignore[misc]
77
"""
8-
ResourceAllocatorContext is a context manager module that temporarily enables dynamic resource allocation
8+
ResourceAllocationStrategy is a context manager module that temporarily enables dynamic resource allocation
99
for all TRT submodules of the given compiled_module. When entering the context,
1010
it sets these submodules to use dynamically allocated resources. Upon exiting, it restores them to their
1111
original (static) resource allocation mode.
@@ -14,17 +14,19 @@ class ResourceAllocatorContext(torch.nn.Module): # type: ignore[misc]
1414
def __init__(
1515
self,
1616
compiled_module: torch.nn.Module,
17+
dynamically_allocate_resources: bool = True
1718
) -> None:
18-
super(ResourceAllocatorContext, self).__init__()
19+
super(ResourceAllocationStrategy, self).__init__()
1920
self.compiled_module = compiled_module
21+
self.dynamically_allocate_resources = dynamically_allocate_resources
2022

2123
def __enter__(self) -> None:
2224
print("Entering resource allocator context")
2325
for name, submodule in self.compiled_module.named_modules():
2426
if "_run_on_acc" in name:
25-
submodule.use_dynamically_allocated_resources(dynamic=True)
27+
submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)
2628

2729
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
2830
for name, submodule in self.compiled_module.named_modules():
2931
if "_run_on_acc" in name:
30-
submodule.use_dynamically_allocated_resources(dynamic=False)
32+
submodule.use_dynamically_allocated_resources(dynamically_allocate_resources=self.dynamically_allocate_resources)

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142
self.serialized_engine = serialized_engine
143143
self.engine = None
144144
self.requires_output_allocator = requires_output_allocator
145-
self.resource_allocation_strategy = 0 # Default to static allocation TODO: Make this configurable with the context manager
145+
self.dynamically_allocate_resources = settings.dynamically_allocate_resources
146146

147147
if (
148148
serialized_engine
@@ -188,9 +188,11 @@ def _pack_engine_info(self) -> List[str | bytes]:
188188
engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str(
189189
int(self.requires_output_allocator)
190190
)
191+
print(f"PROVIDED RESOURCE ALLOCATION STRATEGY: {self.dynamically_allocate_resources}")
191192
engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str(
192-
int(self.resource_allocation_strategy)
193+
int(self.dynamically_allocate_resources)
193194
)
195+
print(engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX])
194196

195197
return engine_info
196198

@@ -219,8 +221,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
219221
def _reset_captured_graph(self) -> None:
220222
self.engine.reset_captured_graph()
221223

222-
def use_dynamically_allocated_resources(self, dynamic: bool = False) -> None:
223-
self.engine._use_dynamically_allocated_resources(dynamic)
224+
def use_dynamically_allocated_resources(self, dynamically_allocate_resources: bool = False) -> None:
225+
self.dynamically_allocate_resources = dynamically_allocate_resources
226+
self.engine.use_dynamically_allocated_resources(self.dynamically_allocate_resources)
224227

225228
def setup_engine(self) -> None:
226229
"""

py/torch_tensorrt/dynamo/runtime/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
PythonTorchTensorRTModule,
44
)
55
from torch_tensorrt.dynamo.runtime._ResourceAllocator import ( # noqa: F401
6-
ResourceAllocatorContext,
6+
ResourceAllocationStrategy,
77
)
88
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401
99
TorchTensorRTModule,

0 commit comments

Comments
 (0)