Status: Closed
Labels: bug
Description
Bug Description
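Compiling a model that uses torch.nn.functional.grid_sample with torch_tensorrt.dynamo.compile produces incorrect output: the TensorRT result diverges from eager PyTorch, while the exported program (torch.export) matches it exactly.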
To Reproduce
import torch
from torch import nn
from torch.nn import functional as F
import torch_tensorrt


class GridSampleTest(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, in_img, grid):
        out_img = F.grid_sample(
            input=in_img,
            grid=grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )
        return out_img


def main():
    torch.manual_seed(1234)
    # input image [B, C, H, W]
    # values: normalized image, uniformly distributed in [0, 1)
    in_img = torch.rand(1, 3, 32, 52, device="cuda")
    # grid sampling coordinates [B, H, W, 2 (x, y)]
    # values: standard Gaussian
    grid = torch.randn(1, 32, 52, 2, device="cuda")
    model = GridSampleTest()
    model = model.cuda()
    out_img_torch = model(in_img, grid)
    exp_model = torch.export.export(
        model,
        (in_img, grid),
        strict=False,
    )
    out_img_exp = exp_model.module()(in_img, grid)
    trt_model = torch_tensorrt.dynamo.compile(
        exp_model,
        inputs=(in_img, grid),
        enabled_precisions=[torch.float32],
        optimization_level=3,
        min_block_size=1,
        use_explicit_typing=True,
    )
    out_img_trt = trt_model(in_img, grid)
    # mean squared error against the eager PyTorch output
    exp_diff = torch.mean((out_img_exp - out_img_torch) ** 2)
    trt_diff = torch.mean((out_img_trt - out_img_torch) ** 2)
    print(f"{exp_diff=}")
    print(f"{trt_diff=}")


if __name__ == "__main__":
    main()
Output:
exp_diff=tensor(0., device='cuda:0')
trt_diff=tensor(0.0196, device='cuda:0')
Steps to reproduce the behavior:
- Run the sample code above.
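Because the grid is drawn from a standard normal distribution, roughly half of the sample points have at least one coordinate outside [-1, 1], so this repro leans heavily on the border-padding path. With padding_mode="border" and align_corners=True, clamping the normalized grid to [-1, 1] beforehand should give the same result in eager PyTorch, since it is equivalent to clamping pixel coordinates to [0, size - 1]. That suggests one possible way to narrow the bug down: also run the compiled model on a pre-clamped grid. The sketch below only checks the eager-mode equivalence on CPU; the name grid_clamped is illustrative, and whether feeding it to trt_model changes trt_diff is not known from this report.

```python
import torch
import torch.nn.functional as F

# Same shapes and seed as the repro, but on CPU so neither a GPU nor TensorRT is needed.
torch.manual_seed(1234)
in_img = torch.rand(1, 3, 32, 52)
grid = torch.randn(1, 32, 52, 2)

# Fraction of sample points with at least one coordinate outside [-1, 1].
outside = (grid.abs() > 1).any(dim=-1).float().mean().item()
print(f"points outside [-1, 1]: {outside:.2%}")

kwargs = dict(mode="bilinear", padding_mode="border", align_corners=True)
out_raw = F.grid_sample(in_img, grid, **kwargs)

# With border padding and align_corners=True, clamping the normalized coordinates
# to [-1, 1] is equivalent to clamping the pixel coordinates to [0, size - 1].
grid_clamped = grid.clamp(-1.0, 1.0)
out_clamped = F.grid_sample(in_img, grid_clamped, **kwargs)

print(f"max |raw - clamped|: {(out_raw - out_clamped).abs().max().item():.3e}")
```

If the eager outputs agree but the compiled model behaves differently on the two grids, that would point toward the out-of-range (border) handling rather than the interpolation itself.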
Expected behavior
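The TensorRT-compiled model should match the eager PyTorch output. For comparison, when grid_sample is replaced with a custom, functionally equivalent implementation, the compiled model produces the correct result: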
exp_diff=tensor(0., device='cuda:0')
trt_diff=tensor(2.4201e-12, device='cuda:0')
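The report does not include the custom implementation it refers to. As an illustration only, a pure-PyTorch equivalent for this specific configuration (mode="bilinear", padding_mode="border", align_corners=True) might look like the sketch below. It is not the reporter's code, and the function name grid_sample_bilinear_border is hypothetical. It uses only elementwise arithmetic, floor, clamp and torch.gather, so it could be swapped into GridSampleTest.forward as a workaround candidate; whether Torch-TensorRT converts each of these ops correctly has not been verified here.

```python
import torch
import torch.nn.functional as F


def grid_sample_bilinear_border(img: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for F.grid_sample with mode="bilinear",
    padding_mode="border", align_corners=True.

    img:  [B, C, H, W]
    grid: [B, Ho, Wo, 2], last dim is (x, y) in normalized [-1, 1] coordinates.
    Returns [B, C, Ho, Wo].
    """
    B, C, H, W = img.shape
    _, Ho, Wo, _ = grid.shape

    # align_corners=True: -1 maps to pixel 0 and +1 maps to pixel (size - 1).
    x = (grid[..., 0] + 1) * 0.5 * (W - 1)
    y = (grid[..., 1] + 1) * 0.5 * (H - 1)

    # padding_mode="border": clamp source coordinates into the valid pixel range.
    x = x.clamp(0, W - 1)
    y = y.clamp(0, H - 1)

    # Integer corner indices; x1/y1 are clamped so they stay in range at the far edge.
    x0 = x.floor().long()
    y0 = y.floor().long()
    x1 = (x0 + 1).clamp(max=W - 1)
    y1 = (y0 + 1).clamp(max=H - 1)

    # Fractional offsets used as bilinear weights, broadcast over channels.
    wx = (x - x0.to(x.dtype)).unsqueeze(1)  # [B, 1, Ho, Wo]
    wy = (y - y0.to(y.dtype)).unsqueeze(1)

    # Gather pixel values at flat indices y * W + x.
    flat = img.reshape(B, C, H * W)

    def gather(yi: torch.Tensor, xi: torch.Tensor) -> torch.Tensor:
        idx = (yi * W + xi).reshape(B, 1, Ho * Wo).expand(B, C, Ho * Wo)
        return torch.gather(flat, 2, idx).reshape(B, C, Ho, Wo)

    v00 = gather(y0, x0)  # top-left
    v01 = gather(y0, x1)  # top-right
    v10 = gather(y1, x0)  # bottom-left
    v11 = gather(y1, x1)  # bottom-right

    top = v00 * (1 - wx) + v01 * wx
    bottom = v10 * (1 - wx) + v11 * wx
    return top * (1 - wy) + bottom * wy


if __name__ == "__main__":
    # Compare against the built-in op on CPU, using the repro's shapes and seed.
    torch.manual_seed(1234)
    in_img = torch.rand(1, 3, 32, 52)
    grid = torch.randn(1, 32, 52, 2)
    ref = F.grid_sample(in_img, grid, mode="bilinear", padding_mode="border", align_corners=True)
    out = grid_sample_bilinear_border(in_img, grid)
    print(f"max abs diff vs F.grid_sample: {(out - ref).abs().max().item():.3e}")
```

The explicit clamp-then-gather structure trades some efficiency for using only basic ops, which is the usual motivation for a hand-written fallback like this.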
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.7.0
- PyTorch Version (e.g. 1.0): 2.7.1
- CPU Architecture: x86-64
- OS (e.g., Linux): Ubuntu 22.04 LTS
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives: from archives
- Python version: 3.12
- CUDA version: 12.6
- GPU models and configuration: RTX 2080Ti Turing
- Any other relevant information:
Additional context