🐛 [Bug] compiled F.grid_sample produces wrong result #3674

@epii2zero

Bug Description

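Compiling "torch.nn.functional.grid_sample" with torch_tensorrt.dynamo.compile produces wrong output. The torch.export program matches eager PyTorch exactly, but the TensorRT-compiled model does not (mean squared error of about 0.0196 in the output below).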

To Reproduce

import torch
from torch import nn
from torch.nn import functional as F
import torch_tensorrt


class GridSampleTest(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, in_img, grid):

    out_img = F.grid_sample(
      input=in_img,
      grid=grid,
      mode="bilinear",
      padding_mode="border",
      align_corners=True,
    )

    return out_img


def main():
  torch.manual_seed(1234)

  # input image [B, C, H, W]
  # values: normalized image, uniform in [0, 1)
  in_img = torch.rand(1, 3, 32, 52, device="cuda")

  # grid sampling coordinates [B, H, W, 2 (x, y)]
  # values: standard Gaussian, so many points fall outside [-1, 1]
  # and exercise padding_mode="border"
  grid = torch.randn(1, 32, 52, 2, device="cuda")

  model = GridSampleTest()
  model = model.cuda()

  out_img_torch = model(in_img, grid)

  exp_model = torch.export.export(
    model,
    (in_img, grid),
    strict=False,
  )

  out_img_exp = exp_model.module()(in_img, grid)

  trt_model = torch_tensorrt.dynamo.compile(
    exp_model,
    inputs=(in_img, grid),
    enabled_precisions=[torch.float32],
    optimization_level=3,
    min_block_size=1,
    use_explicit_typing=True,
  )

  out_img_trt = trt_model(in_img, grid)

  exp_diff = torch.mean((out_img_exp - out_img_torch)**2)
  trt_diff = torch.mean((out_img_trt - out_img_torch)**2)

  print(f"{exp_diff=}")
  print(f"{trt_diff=}")


if __name__ == "__main__":
  main()

Output:

exp_diff=tensor(0., device='cuda:0')
trt_diff=tensor(0.0196, device='cuda:0')
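A possible next step for narrowing this down, offered as a hypothetical helper and not part of the original report: because the grid is drawn from a standard Gaussian, many sampling points lie outside [-1, 1] and are clamped by padding_mode="border". Splitting the squared error by whether each point is inside or outside that range would show whether the mismatch is confined to the border-clamped samples. The function name split_error_by_border is illustrative only.

```python
import torch


def split_error_by_border(out_ref: torch.Tensor, out_test: torch.Tensor, grid: torch.Tensor):
    """Hypothetical diagnostic: mean squared error for sampling points inside
    vs. outside the normalized range [-1, 1].

    out_ref, out_test: [B, C, Hg, Wg] outputs of grid_sample
    grid:              [B, Hg, Wg, 2] sampling grid that produced them
    Returns (inside_mse, outside_mse); either is nan if that set is empty.
    """
    # A point counts as "inside" only if both x and y lie in [-1, 1].
    inside = (grid.abs() <= 1).all(dim=-1)             # [B, Hg, Wg]
    # Squared error averaged over channels, one value per sampling point.
    sq_err = ((out_test - out_ref) ** 2).mean(dim=1)   # [B, Hg, Wg]
    return sq_err[inside].mean().item(), sq_err[~inside].mean().item()
```

In the reproduction script it could be called as split_error_by_border(out_img_torch, out_img_trt, grid); no results from such a call are reported in this issue.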

Steps to reproduce the behavior:

  1. Execute the sample code above.

Expected behavior

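The compiled model should match eager PyTorch up to floating-point rounding. For reference, when F.grid_sample is replaced with a custom, functionally equivalent grid_sample implementation, the compiled model produces the expected result: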

exp_diff=tensor(0., device='cuda:0')
trt_diff=tensor(2.4201e-12, device='cuda:0')
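The custom implementation used for that comparison is not included in the report. As a hypothetical sketch of what such a replacement could look like, the function below reimplements bilinear sampling with padding_mode="border" and align_corners=True using only elementwise arithmetic, clamp, floor and gather. The name grid_sample_bilinear_border is illustrative, and whether these ops lower cleanly through Torch-TensorRT is an assumption to check, not something verified here.

```python
import torch
import torch.nn.functional as F


def grid_sample_bilinear_border(img: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for F.grid_sample(mode="bilinear",
    padding_mode="border", align_corners=True).

    img:  [B, C, H, W]
    grid: [B, Hg, Wg, 2], last dim is (x, y) in normalized coordinates
    """
    B, C, H, W = img.shape
    _, Hg, Wg, _ = grid.shape

    # align_corners=True: -1 maps to pixel 0, +1 maps to pixel (size - 1).
    x = (grid[..., 0] + 1) * (W - 1) / 2
    y = (grid[..., 1] + 1) * (H - 1) / 2

    # padding_mode="border": clamp the sample location into the image.
    x = x.clamp(0, W - 1)
    y = y.clamp(0, H - 1)

    # Top-left neighbor, and right/bottom neighbors clamped to the last pixel.
    # When x0 == W - 1 the right weight below is 0, so the clamp is harmless.
    x0 = x.floor()
    y0 = y.floor()
    x1 = (x0 + 1).clamp(max=W - 1)
    y1 = (y0 + 1).clamp(max=H - 1)

    # Bilinear weights from the fractional part of the sample location.
    wx1 = x - x0
    wy1 = y - y0
    wx0 = 1 - wx1
    wy0 = 1 - wy1

    flat = img.reshape(B, C, H * W)

    def sample_at(yi: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
        # Linear index into the flattened H*W plane, shared across channels.
        idx = (yi * W + xi).long().reshape(B, 1, Hg * Wg).expand(B, C, Hg * Wg)
        return flat.gather(2, idx).reshape(B, C, Hg, Wg)

    return (
        sample_at(y0, x0) * (wy0 * wx0).unsqueeze(1)
        + sample_at(y0, x1) * (wy0 * wx1).unsqueeze(1)
        + sample_at(y1, x0) * (wy1 * wx0).unsqueeze(1)
        + sample_at(y1, x1) * (wy1 * wx1).unsqueeze(1)
    )
```

In GridSampleTest.forward it would take the place of the F.grid_sample call; for the mode and padding settings used here it should agree with F.grid_sample up to float32 rounding.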

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.7.0
  • PyTorch Version (e.g. 1.0): 2.7.1
  • CPU Architecture: x86-64
  • OS (e.g., Linux): Ubuntu 22.04 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives: from archives
  • Python version: 3.12
  • CUDA version: 12.6
  • GPU models and configuration: NVIDIA GeForce RTX 2080 Ti (Turing)
  • Any other relevant information:

Additional context

Labels: bug (Something isn't working)
