Description
Hi all! Compiling with global partitioning (`use_fast_partitioner=False`) fails with the dynamo backend, and I couldn't pinpoint why (I tried various compilation parameters).
How to Reproduce:
System:
Cuda Driver Version: 535.104.12
GPU: Nvidia Tesla T4
Python: 3.11.10
Dependencies (wheels):
https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl
https://download.pytorch.org/whl/cu121/torch_tensorrt-2.4.0%2Bcu121-cp311-cp311-linux_x86_64.whl
Script to reproduce:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt


class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # A 538x538 input becomes 269x269, then 134x134 after the two 2x2 max-pools
        self.fc1 = nn.Linear(32 * 134 * 134, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32 * 134 * 134)
        x = self.fc1(x)
        return x


def compile_to_tensorrt() -> None:
    batch_size, tile_size = 1, 538
    model = SimpleCNN().to(dtype=torch.float16, device=torch.device("cuda"))
    model.eval()
    with torch.no_grad():
        inputs = torch.randn(
            batch_size, 3, tile_size, tile_size, device="cuda", dtype=torch.float16
        )
        print("Compiling model...")
        _trt_graph_module = torch_tensorrt.compile(
            model,
            ir="dynamo",
            inputs=[inputs],
            enabled_precisions={torch.float16},
            use_fast_partitioner=False,  # select the global partitioner (the failing path)
        )


if __name__ == "__main__":
    compile_to_tensorrt()
```
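As a sanity check on the hard-coded `fc1` size, the sketch below applies the standard Conv2d and MaxPool2d output-size formulas (dilation 1, no pooling padding, `ceil_mode=False`) to a 538x538 tile. The helper names are hypothetical and only illustrate the arithmetic. The result matches `32 * 134 * 134`, which suggests the failure is not a shape mismatch in `fc1`.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Conv2d output size with dilation 1: floor((H + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1


def maxpool2d_out(size: int, kernel: int, stride: int) -> int:
    # MaxPool2d output size with no padding and ceil_mode=False: floor((H - k) / s) + 1
    return (size - kernel) // stride + 1


def fc1_in_features(tile_size: int, channels: int = 32) -> int:
    # Hypothetical helper that mirrors the layer sequence in SimpleCNN.forward
    h = tile_size
    h = conv2d_out(h, kernel=3, stride=1, padding=1)  # conv1 keeps the spatial size
    h = maxpool2d_out(h, kernel=2, stride=2)          # 538 -> 269
    h = conv2d_out(h, kernel=3, stride=1, padding=1)  # conv2 keeps the spatial size
    h = maxpool2d_out(h, kernel=2, stride=2)          # 269 -> 134
    return channels * h * h


print(fc1_in_features(538))  # 574592, i.e. 32 * 134 * 134
```

A different `tile_size` would need a matching change to `nn.Linear(...)`.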
Error and traceback:
```
Traceback (most recent call last):
    compile_to_tensorrt()
  File "reproduce.py", line 37, in compile_to_tensorrt
    _trt_graph_module = torch_tensorrt.compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/_compile.py", line 249, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
    trt_gm = compile_module(gm, inputs, settings)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/dynamo/_compiler.py", line 365, in compile_module
    for node in submodule.graph.nodes
                ^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Module' object has no attribute 'graph'
```
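The failing frame reads `submodule.graph.nodes`, so at least one submodule reached by that loop is a plain `Module` rather than a `torch.fx.GraphModule`. A small diagnostic sketch, assuming you can get hold of the partitioned module (for example in a debugger paused at that frame in `_compiler.py`), is to list which of its children lack `.graph`. The function name is hypothetical and not part of Torch-TensorRT.

```python
from typing import Iterable, List, Tuple


def find_children_without_graph(named_children: Iterable[Tuple[str, object]]) -> List[str]:
    """Return the names of children that have no `.graph` attribute.

    Intended to be fed `module.named_children()`. A torch.fx.GraphModule child
    exposes `.graph`; a plain torch.nn.Module child does not, so accessing
    `child.graph` on it raises an AttributeError like the one above.
    """
    # hasattr() returns False because nn.Module.__getattr__ raises AttributeError
    return [name for name, child in named_children if not hasattr(child, "graph")]
```

Not every child of a GraphModule is itself a GraphModule. On a model traced with `torch.fx.symbolic_trace(SimpleCNN())`, for instance, the helper flags the ordinary leaf layers `conv1`, `conv2` and `fc1`.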