🐛 [Bug] ERROR: [Torch-TensorRT] - F.pad(input, self.padding, 'reflect') cannot be compiled end to end by Torch-TensorRT.TorchScript. #1704

@zshn25

Description

ERROR: [Torch-TensorRT] - F.pad(input, self.padding, 'reflect') cannot be compiled end to end by Torch-TensorRT.TorchScript.
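In PyTorch 1.13, `nn.ReflectionPad2d.forward` calls `F.pad(input, self.padding, 'reflect')`, and TorchScript records that call as a single `aten::pad` node. That node is the operator the log below reports as unsupported. The following is a minimal sketch for confirming which operators end up in the scripted graph before handing it to Torch-TensorRT. It runs on CPU, reuses the `Padder` module from the reproduction, and relies on `ScriptModule.inlined_graph`, which inlines submodule calls such as `self.pad(x)`. The helper name `list_graph_ops` is hypothetical.

```python
import torch
from torch import nn


class Padder(nn.Module):
    """Same module as in the reproduction, kept on CPU for inspection."""

    def __init__(self) -> None:
        super().__init__()
        self.pad = nn.ReflectionPad2d(1)

    def forward(self, x):
        return self.pad(x)


def list_graph_ops(scripted: torch.jit.ScriptModule) -> set:
    """Return the node kinds (e.g. 'aten::pad') of the inlined forward graph."""
    # inlined_graph expands the prim::CallMethod on self.pad, so the ops of
    # the submodule become visible; only top-level nodes are walked here.
    return {node.kind() for node in scripted.inlined_graph.nodes()}


scripted = torch.jit.script(Padder().eval())
print(sorted(list_graph_ops(scripted)))
```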

To Reproduce

Steps to reproduce the behavior:

import torch
from torch import nn
import torch_tensorrt

class Padder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pad = nn.ReflectionPad2d(1)

    def forward(self, x):
        return self.pad(x)

model = Padder().eval().to("cuda")
input_img = torch.randn((1, 3, 480, 768), requires_grad=False).to("cuda").detach()
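# torch.jit.script does not take example inputs positionally (its second
# positional argument is the deprecated `optimize` flag), so pass only the module.
scripted_model = torch.jit.script(model)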

spec = {
    "forward": torch_tensorrt.ts.TensorRTCompileSpec(
        **{
            "inputs": [torch_tensorrt.Input([1, 3, 480, 768])],
            "enabled_precisions": {torch.float, torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_avg_timing_iters": 1,
        }
    )
}
trt_model = torch._C._jit_to_backend("tensorrt", scripted_model, spec)

Output:

ERROR: [Torch-TensorRT] - Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.
Unsupported operators listed below:
  - aten::pad(Tensor self, int[] pad, str mode="constant", float? value=None) -> Tensor
You can either implement converters for these ops in your application or request implementation
https://www.github.com/nvidia/Torch-TensorRT/issues

In Module:

ERROR: [Torch-TensorRT] - Unsupported operator: aten::pad(Tensor self, int[] pad, str mode="constant", float? value=None) -> Tensor
  File "/home/user/miniconda3/envs/inference/lib/python3.8/site-packages/torch/nn/modules/padding.py", line 178
    def forward(self, input: Tensor) -> Tensor:
        return F.pad(input, self.padding, 'reflect')
               ~~~~~ <--- HERE

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[17], line 33
     13 scripted_model = torch.jit.script(model, input_img)
     15 spec = {
     16     "forward": torch_tensorrt.ts.TensorRTCompileSpec(
     17         **{
   (...)
     31     )
     32 }
---> 33 trt_model = torch._C._jit_to_backend("tensorrt", scripted_model, spec)

RuntimeError: [Error thrown at /workspace/project/py/torch_tensorrt/csrc/tensorrt_backend.cpp:68] Expected core::CheckMethodOperatorSupport(mod, it->key().toStringRef()) to be true but got false
Method forwardcannot be compiled by Torch-TensorRT
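Until `aten::pad` has a converter, one possible workaround (a sketch, not an official fix) is to express the 1-pixel reflection padding with slicing and concatenation, so the scripted graph contains `aten::slice` and `aten::cat` instead of `aten::pad`. It is an assumption, not verified here, that the Torch-TensorRT 1.3.0 converters for those two ops accept these calls (including the negative slice indices). The class name `ReflectionPad2dBy1` is hypothetical, and it only covers a padding of 1, as in the reproduction.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ReflectionPad2dBy1(nn.Module):
    """Hypothetical stand-in for nn.ReflectionPad2d(1) using only slice + cat.

    Assumes H >= 2 and W >= 2 (true for the 480x768 input in the issue).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Width: reflection excludes the edge pixel, so the left pad is
        # column 1 and the right pad is column W-2.
        x = torch.cat([x[..., 1:2], x, x[..., -2:-1]], dim=-1)
        # Height: same trick on rows; corners come out right because
        # reflection padding is separable per dimension.
        x = torch.cat([x[..., 1:2, :], x, x[..., -2:-1, :]], dim=-2)
        return x


class Padder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pad = ReflectionPad2dBy1()

    def forward(self, x):
        return self.pad(x)


x = torch.randn(1, 3, 480, 768)
scripted = torch.jit.script(Padder().eval())
# Same values as the original F.pad(..., 'reflect') path.
torch.testing.assert_close(scripted(x), F.pad(x, (1, 1, 1, 1), mode="reflect"))
# The scripted graph should no longer contain aten::pad.
ops = {n.kind() for n in scripted.inlined_graph.nodes()}
print("aten::pad" in ops)
```

To try it with the reproduction, replace `nn.ReflectionPad2d(1)` with `ReflectionPad2dBy1()` in `Padder.__init__` and rerun the compile step; whether the resulting graph then compiles end to end still needs to be confirmed on a TensorRT setup.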

Expected behavior

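`torch._C._jit_to_backend` returns a compiled `trt_model` without raising an error, i.e. `aten::pad` in `'reflect'` mode is converted to TensorRT.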

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g. 1.0): 1.13.1
  • CPU Architecture: x64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): conda
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.8
  • CUDA version: 11.7
  • GPU models and configuration: V100
  • Any other relevant information:
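Most of the fields above can be filled in programmatically. A minimal sketch using `torch`, `platform`, and an optional `torch_tensorrt` import follows; the helper name `collect_env` is hypothetical, and the install method and build command cannot be detected this way. PyTorch also ships `python -m torch.utils.collect_env`, which prints a fuller report.

```python
import platform

import torch


def collect_env() -> dict:
    """Gather the environment fields requested by the issue template."""
    try:
        import torch_tensorrt  # optional: only present if installed

        trt_version = torch_tensorrt.__version__
    except ImportError:
        trt_version = None
    gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
    return {
        "Torch-TensorRT Version": trt_version,
        "PyTorch Version": torch.__version__,
        "CPU Architecture": platform.machine(),  # e.g. 'x86_64' on x64 Linux
        "OS": platform.system(),
        "Python version": platform.python_version(),
        "CUDA version": torch.version.cuda,  # None on CPU-only builds
        "GPU models and configuration": gpu,
    }


for key, value in collect_env().items():
    print(f"{key}: {value}")
```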

Additional context

Labels: bug