Skip to content

Conversation

HolyWu
Copy link
Contributor

@HolyWu HolyWu commented Dec 31, 2024

Description

This PR fixes two issues in grid_sample.

1. PyTorch defines interpolation mode enum as "bilinear"=0 and "nearest"=1. But the converter impl has 0 and 1 reversed, causing discrepancy between Torch and Torch-TRT.

import os

import torch
import torch.nn.functional as F
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        return F.grid_sample(x, grid, mode="bilinear", align_corners=False)


with torch.inference_mode():
    model = MyModule().eval().cuda()

    inputs = [torch.randn(1, 3, 224, 224, device="cuda"), torch.randn(1, 224, 224, 2, device="cuda")]

    trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1)

    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
    print("assert_close passed")
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.grid_sampler_2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.grid_sampler_2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(1, 3, 224, 224), (1, 224, 224, 2)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return grid_sampler_2d
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node grid (kind: grid, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: grid [shape=[1, 224, 224, 2], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node grid [grid] (Inputs: () | Outputs: (grid: (1, 224, 224, 2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /grid_sampler_2d (kind: aten.grid_sampler_2d.default, args: ('x <Node>', 'grid <Node>', '0 <int>', '0 <int>', 'False <bool>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /grid_sampler_2d [aten.grid_sampler_2d.default] (Inputs: (x: (1, 3, 224, 224)@torch.float32, grid: (1, 224, 224, 2)@torch.float32, 0, 0, False) | Outputs: (grid_sampler_2d: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('grid_sampler_2d <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, 224, 224), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (grid_sampler_2d: (1, 3, 224, 224)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001957
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.120683
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 12180 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 3953 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 0
DEBUG: [Torch-TensorRT] - - Runner scratch: 0 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Input binding name: grid has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 2, Torch binding index: 2
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: x
      shape: [1, 3, 224, 224]
      dtype: Float
    id: 1
      name: grid
      shape: [1, 224, 224, 2]
      dtype: Float
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [1, 3, 224, 224]
      dtype: Float
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False)

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 224, 224)@float32, Tensor: (1, 224, 224, 2)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32, Tensor: (1, 224, 224, 2)@float32]
     Number of Operators in Engine: 1
     Engine Outputs: List[Tensor: (1, 3, 224, 224)@float32]
    ...
   Outputs: List[Tensor: (1, 3, 224, 224)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input shape changed None -> (1,3,224,224)(1,224,224,2)
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [1, 3, 224, 224]
DEBUG: [Torch-TensorRT] - Input Name: grid Shape: [1, 224, 224, 2]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [1, 3, 224, 224]
Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 25, in <module>
    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
  File "C:\Python312\Lib\site-packages\torch\testing\_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 69917 / 150528 (46.4%)
Greatest absolute difference: 3.082557439804077 at index (0, 2, 213, 105) (up to 0.005 allowed)
Greatest relative difference: 280988.21875 at index (0, 1, 184, 207) (up to 0.005 allowed)

2. PyTorch dispatches grid_sampler to cudnn_grid_sampler when mode="bilinear", padding_mode="zeros", align_corners=True, causing graph breaks due to unsupported node.

import os

import torch
import torch.nn.functional as F
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        return F.grid_sample(x, grid, mode="bilinear", align_corners=True)


with torch.inference_mode():
    model = MyModule().eval().cuda()

    inputs = [torch.randn(1, 3, 224, 224, device="cuda"), torch.randn(1, 224, 224, 2, device="cuda")]

    trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1)

    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
    print("assert_close passed")
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler.default](args = (%x, %grid, 0, 0, True), kwargs = {})
    return (grid_sampler,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.cudnn_grid_sampler.default + Operator Count: 1

WARNING:torch_tensorrt.dynamo._compiler:0 supported operations detected in subgraph containing 1 computational nodes. Skipping this subgraph, since min_block_size was detected to be 1
assert_close passed

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Dec 31, 2024
@github-actions github-actions bot requested a review from gs-olive December 31, 2024 18:04
Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@peri044 peri044 requested review from apbose and removed request for gs-olive January 2, 2025 22:12
@apbose
Copy link
Collaborator

apbose commented Jan 21, 2025

Looks good!

@apbose
Copy link
Collaborator

apbose commented Mar 14, 2025

@HolyWu could you please rebase and run the CI?

@HolyWu
Copy link
Contributor Author

HolyWu commented Mar 14, 2025

@apbose Rebased. Please re-run the CI.

@HolyWu
Copy link
Contributor Author

HolyWu commented Apr 4, 2025

@apbose I think there is no additional change needed and it's good to merge. Can we merge it soon?

@apbose apbose merged commit 76a776e into pytorch:main Apr 4, 2025
2 checks passed
@HolyWu HolyWu deleted the fix_grid_sample branch April 4, 2025 02:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants