🐛 [Bug] Importing torchao first breaks torch_tensorrt.dynamo.compile during run_decompositions #3226

@dgcnz

Bug Description

Importing torchao before torch_tensorrt causes torch_tensorrt.dynamo.compile to fail on a model that calls F.interpolate. The failure occurs inside run_decompositions with:

AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel

To Reproduce

import torchao
import torch_tensorrt
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        _, _, h, w = x.shape
        z = F.interpolate(x, [h*2, w*2])
        return z

model = Model().cuda()
inputs = (torch.randn((2, 4, 8, 8)).cuda(),)

with torch.no_grad():
    ep = torch.export.export(
        model,
        args=inputs,
        strict=True
    )
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(
            ep,
            inputs,
            reuse_cached_engines=False,
            cache_built_engines=False,
            require_full_compilation=True,
            min_block_size=1,
        )

Logs:

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/projects/scripts/mre/torchao_tensorrt_import.py", line 25, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 228, in compile
    exported_program = exported_program.run_decompositions(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 116, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 1111, in run_decompositions
    return _decompose_exported_program(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 654, in _decompose_exported_program
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 446, in _decompose_and_get_gm_with_new_signature_constants
    gm, graph_signature = aot_export_module(
                          ^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1262, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1497, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 625, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 194, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 859, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6495, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 64, in inner
    return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 37, in autograd_not_implemented_inner
    result = operator(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
    return self.py_kernels[dk](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 3184, in upsample_nearest2d_vec
    return torch.ops.aten.upsample_nearest2d.default(input, osize)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
    return self.py_kernels[dk](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_decomp/__init__.py", line 376, in _special_op_to_decompose_cia
    raise AssertionError(
AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel

While executing %upsample_nearest2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_nearest2d.vec](args = (%x, [16, 16], None), kwargs = {})
Original traceback:
  File "/projects/scripts/mre/torchao_tensorrt_import.py", line 12, in forward
    z = F.interpolate(x, [h*2, w*2])

Expected behavior

The import order between torch_tensorrt and torchao should not matter.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0.dev20241008+cu124
  • PyTorch Version (e.g. 1.0): 2.6.0.dev20241009+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): -
  • Are you using local sources or building from archives: -
  • Python version: 3.12.7
  • CUDA version: 12.4
  • GPU models and configuration: NVIDIA RTX 4060 TI
  • Any other relevant information: torchao==0.6.0.dev20241009+cu124

Additional context

If torch_tensorrt is imported before torchao, the error disappears.
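
Until the root cause is fixed, a script could at least detect the problematic import order up front. The following is only a minimal sketch, not part of either library: the helper name `imported_before` is hypothetical, and it assumes that the insertion order of `sys.modules` reflects the order in which top-level imports started, which holds for ordinary `import` statements but not if a module was removed and re-inserted.

```python
import sys
from typing import Mapping, Optional


def imported_before(
    first: str,
    second: str,
    modules: Optional[Mapping[str, object]] = None,
) -> bool:
    """Return True if `first` appears before `second` in the module table.

    Hypothetical helper. `modules` defaults to sys.modules, a dict that
    keeps insertion order, so position approximates import order.
    Returns False if either module has not been imported at all.
    """
    table = sys.modules if modules is None else modules
    if first not in table or second not in table:
        return False
    names = list(table)
    return names.index(first) < names.index(second)


# Guard to place after the imports and before torch_tensorrt.dynamo.compile.
if imported_before("torchao", "torch_tensorrt"):
    print(
        "warning: torchao was imported before torch_tensorrt; "
        "run_decompositions may fail on F.interpolate (see this issue)",
        file=sys.stderr,
    )
```

This only reports the condition; the actual workaround remains putting `import torch_tensorrt` above `import torchao`.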