🐛 [Bug] global partitioner does not work while compiling with dynamo #3157

@seymurkafkas

Description

Hi all! Compiling with the global partitioner (use_fast_partitioner=False) fails with the dynamo backend, and I couldn't pinpoint why (I tried various compilation parameters).

How to Reproduce:

System:

CUDA Driver Version: 535.104.12
GPU: Nvidia Tesla T4
Python: 3.11.10

Dependencies (wheels):

https://download.pytorch.org/whl/cu121/torch-2.4.1%2Bcu121-cp311-cp311-linux_x86_64.whl
https://download.pytorch.org/whl/cu121/torch_tensorrt-2.4.0%2Bcu121-cp311-cp311-linux_x86_64.whl
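For reports like this one, the Python and package versions listed above can be collected programmatically. Below is a minimal sketch using only the standard library (importlib.metadata, platform). The helper name collect_environment and its default package list are illustrative assumptions. The driver version and GPU model still have to come from a tool such as nvidia-smi.

```python
import importlib.metadata
import platform


def collect_environment(packages=("torch", "torch_tensorrt", "tensorrt")):
    """Return the Python version and installed package versions (hypothetical helper).

    Packages that are not installed map to None instead of raising.
    """
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            # Package is not installed in this interpreter.
            info[name] = None
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```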

Script to reproduce:

import torch
import torch_tensorrt
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 134 * 134, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

def compile_to_tensorrt() -> None:
    batch_size, tile_size = 1, 538
    model = SimpleCNN().to(dtype=torch.float16, device=torch.device("cuda"))
    model.eval()
    with torch.no_grad():
        inputs = torch.randn(
            batch_size, 3, tile_size, tile_size, device="cuda", dtype=torch.float16
        )
        print("Compiling model...")
        _trt_graph_module = torch_tensorrt.compile(
            model,
            ir="dynamo",
            inputs=[inputs],
            enabled_precisions={torch.float16},
            use_fast_partitioner=False,  # False selects the global partitioner instead of the default fast one
        )

if __name__ == "__main__":
    compile_to_tensorrt()
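The 32 * 134 * 134 in fc1 is hard-coded for a 538×538 input. The quick pure-Python check below assumes the standard output-size formula for Conv2d and max_pool2d with dilation 1 and ceil_mode=False, which are the defaults the script relies on. It confirms that the script is shape-consistent, so the failure below is unlikely to be a shape mismatch. The helper names conv_out and fc1_in_features are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output spatial size of Conv2d / MaxPool2d (assumes dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


def fc1_in_features(tile_size, channels=32):
    """Trace one spatial dimension through SimpleCNN; the comments assume tile_size=538."""
    h = conv_out(tile_size, kernel=3, stride=1, padding=1)  # conv1: 538 -> 538
    h = conv_out(h, kernel=2, stride=2)                      # max_pool2d: 538 -> 269
    h = conv_out(h, kernel=3, stride=1, padding=1)           # conv2: 269 -> 269
    h = conv_out(h, kernel=2, stride=2)                      # max_pool2d: 269 -> 134 (floored)
    return channels * h * h


print(fc1_in_features(538))  # 574592, i.e. 32 * 134 * 134
```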

Error and traceback:

Traceback (most recent call last):
    compile_to_tensorrt()
  File "reproduce.py", line 37, in compile_to_tensorrt
    _trt_graph_module = torch_tensorrt.compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/_compile.py", line 249, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/dynamo/_compiler.py", line 230, in compile
    trt_gm = compile_module(gm, inputs, settings)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "site-packages/torch_tensorrt/dynamo/_compiler.py", line 365, in compile_module
    for node in submodule.graph.nodes
                ^^^^^^^^^^^^^^^
  File "site-packages/torch/nn/modules/module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Module' object has no attribute 'graph'
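The last frame shows compile_module iterating submodule.graph.nodes on an object whose type name is Module. This suggests that a plain torch.nn.Module container, rather than a torch.fx.GraphModule, reached that loop. The sketch below reproduces only that failure mode. It uses hypothetical pure-Python stand-ins (Module, Graph, GraphModule and count_nodes are not the real torch or Torch-TensorRT classes) and adds the kind of isinstance guard that avoids the error. The actual fix in Torch-TensorRT may well be different, for example ensuring that the global partitioner never yields such a child.

```python
class Module:
    """Hypothetical stand-in mimicking torch.nn.Module's failing attribute lookup."""

    def __getattr__(self, name):
        # Called only when normal lookup fails, as in torch.nn.Module.__getattr__.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


class Graph:
    """Hypothetical stand-in for torch.fx.Graph: just holds a list of nodes."""

    def __init__(self, nodes):
        self.nodes = nodes


class GraphModule(Module):
    """Hypothetical stand-in for torch.fx.GraphModule: a Module that carries a graph."""

    def __init__(self, nodes):
        self.graph = Graph(nodes)


def count_nodes(submodules):
    """Count graph nodes across submodules, skipping children that carry no graph."""
    total = 0
    for sub in submodules:
        if not isinstance(sub, GraphModule):
            # A plain Module lands here; touching sub.graph would raise the AttributeError.
            continue
        total += len(list(sub.graph.nodes))
    return total
```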

Metadata

Labels: bug (Something isn't working)
