-
Notifications
You must be signed in to change notification settings - Fork 370
Closed
Labels
bug: Something isn't working
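Description
When the same graph is compiled a second time with a different static input shape ((2, 3, 224, 224) instead of (1, 3, 224, 224)) using the same `engine_cache_dir`, the engine cache returns the engine built for the first shape. Running the second model then fails with a static dimension mismatch in `setInputShape`.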
To Reproduce
import os
from tempfile import gettempdir
import torch
import torch_tensorrt
class MyModule(torch.nn.Module):
    """Minimal module whose forward pass is a single element-wise ReLU.

    Used as the smallest possible graph (one ``aten.relu`` node) to
    reproduce the engine-cache behaviour.
    """

    def forward(self, x):
        """Return ``relu(x)`` as a new tensor; ``x`` itself is not modified."""
        return torch.nn.functional.relu(x)
def _compile_and_run(model, batch_size, device):
    """Compile ``model`` for a static (batch_size, 3, 224, 224) input and run it once.

    Every call uses the same on-disk engine cache directory, so a second call
    with a different ``batch_size`` exercises the cache lookup path.

    Returns:
        A ``(sample_inputs, compiled_model)`` tuple.
    """
    sample_inputs = [torch.randn(batch_size, 3, 224, 224, device=device)]
    compiled_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=sample_inputs,
        enabled_precisions={torch.float},
        device=device,
        make_refitable=True,
        debug=True,
        min_block_size=1,
        engine_cache_dir=os.path.join(gettempdir(), "torchtrt_issue3148"),
    )
    compiled_model(*sample_inputs)
    return sample_inputs, compiled_model


with torch.inference_mode():
    device = torch.device("cuda", 0)
    model = MyModule().eval().to(device)

    # First compile: builds an engine for batch size 1 and stores it in the cache.
    inputs1, trt_model1 = _compile_and_run(model, 1, device)

    print("\n========================================\n")

    # Second compile: same graph, batch size 2, same cache directory.
    inputs2, trt_model2 = _compile_and_run(model, 2, device)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
INFO:torch_tensorrt.dynamo._engine_cache:Disk engine cache initialized (cache directory:/tmp/torchtrt_issue3148, max size: 1073741824)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1, 3, 224, 224)]
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return relu
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /relu (kind: aten.relu.default, args: ('x <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /relu [aten.relu.default] (Inputs: (x: (1, 3, 224, 224)@torch.float32) | Outputs: (relu: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('relu <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, 224, 224), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (relu: (1, 3, 224, 224)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004144
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Building weight name mapping...
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:03.345116
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 12932 bytes of Memory
DEBUG:torch_tensorrt.dynamo._engine_cache:The engine added to cache, saved to /tmp/torchtrt_issue3148/qcp2nbn7adw2zbhxzqql4er37brkw33awbzf4jqx33pbhuvino3/blob.bin
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)
Graph Structure:
Inputs: List[Tensor: (1, 3, 224, 224)@float32]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32]
Number of Operators in Engine: 1
Engine Outputs: List[Tensor: (1, 3, 224, 224)@float32]
...
Outputs: List[Tensor: (1, 3, 224, 224)@float32]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 1.0
Most Operators in a TRT Engine: 1
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
========================================
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return (relu,)
INFO:torch_tensorrt.dynamo._engine_cache:Disk engine cache initialized (cache directory:/tmp/torchtrt_issue3148, max size: 1073741824)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(2, 3, 224, 224)]
graph():
%x : [num_users=1] = placeholder[target=x]
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
return relu
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
DEBUG:torch_tensorrt.dynamo._engine_cache:Engine found in cache, loaded from /tmp/torchtrt_issue3148/qcp2nbn7adw2zbhxzqql4er37brkw33awbzf4jqx33pbhuvino3/blob.bin
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Found the cached engine that corresponds to this graph. It is directly loaded.
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[2, 3, 224, 224], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (2, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /relu (kind: aten.relu.default, args: ('x <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /relu [aten.relu.default] (Inputs: (x: (2, 3, 224, 224)@torch.float32) | Outputs: (relu: (2, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('relu <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(2, 3, 224, 224), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (relu: (2, 3, 224, 224)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001975
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)
Graph Structure:
Inputs: List[Tensor: (2, 3, 224, 224)@float32]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (2, 3, 224, 224)@float32]
Number of Operators in Engine: 1
Engine Outputs: List[Tensor: (2, 3, 224, 224)@float32]
...
Outputs: List[Tensor: (2, 3, 224, 224)@float32]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 1.0
Most Operators in a TRT Engine: 1
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
ERROR: [Torch-TensorRT] - IExecutionContext::setInputShape: Error Code 3: API Usage Error (Parameter check failed, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape for x. Set dimensions are [2,3,224,224]. Expected dimensions are [1,3,224,224].)
Traceback (most recent call last):
File "/home/holywu/test.py", line 45, in <module>
trt_model2(*inputs2)
File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 784, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 361, in __call__
raise e
File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 348, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<eval_with_key>.85", line 6, in forward
File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/_features.py", line 56, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 279, in forward
outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/holywu/.local/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:231] Expected compiled_engine->exec_ctx->setInputShape(name.c_str(), dims) to be true but got false
Error while setting the input shape
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240907+cu124
- PyTorch Version (e.g. 1.0): 2.5.0.dev20240907+cu124
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.12
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Metadata
Metadata
Assignees
Labels
bug: Something isn't working