Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2463,3 +2463,45 @@ def aten_ops_pad(
mode=args_bounds_check(args, 2, "constant"),
value=args_bounds_check(args, 3, None),
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec)
def upsample_nearest2d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
out_shape=args_bounds_check(args, 1),
scale_factors=args_bounds_check(args, 2),
resize_mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec)
def upsample_bilinear2d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
out_shape=args_bounds_check(args, 1),
scale_factors=args_bounds_check(args, 3),
resize_mode="bilinear",
align_corners=args_bounds_check(args, 2),
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@
topk,
unary,
unsqueeze,
upsample,
)
67 changes: 67 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Optional, Sequence

import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor


def upsample(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
out_shape: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
resize_mode: str,
align_corners: bool,
) -> TRTTensor:
resize_layer = ctx.net.add_resize(input)
# output size calculation
# Pytorch assumes that one of out_shape/scale_factor is None
# Pytorch assumes that dimensions match for out_shape/scale factor
if out_shape is not None:
resize_layer.shape = list(input.shape)[:2] + list(out_shape)
elif scale_factors is not None:
resize_layer.scales = [1.0, 1.0] + list(scale_factors)
else:
raise RuntimeError(
f"At least one of out_shape and scale_factors should be specified."
)

# interpolate mode
if resize_mode == "nearest" or None:
resize_layer.resize_mode = trt.ResizeMode.NEAREST
elif resize_mode == "bilinear":
resize_layer.resize_mode = trt.ResizeMode.LINEAR
if align_corners is None or not align_corners:
raise RuntimeError(
f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT."
)
else:
raise RuntimeError(
f"Interpolation mode is {resize_mode} which is not supported by TensorRT."
)

if resize_mode == "nearest":
resize_layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ASYMMETRIC
)
elif resize_mode == "bilinear":
# align corners
if align_corners is not None and align_corners:
resize_layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
)
else:
resize_layer.coordinate_transformation = (
trt.ResizeCoordinateTransformation.ASYMMETRIC
)

set_layer_name(resize_layer, target, name, source_ir)

out = resize_layer.get_output(0)
return out
97 changes: 97 additions & 0 deletions tests/py/dynamo/conversion/test_upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestUpsampleConverter(DispatchTestCase):
# test case for nearest upsample, using output_size, scale_factors is disabled here
@parameterized.expand(
[
("upsample_nearest2d.vec_outshape_0", (2, 2), (4, 4)),
("upsample_nearest2d.vec_outshape_1", (2, 2), (5, 5)),
]
)
def test_upsample_nearest_output_shape(self, _, input_shape, output_shape):
class Upsample(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.upsample_nearest2d.vec(input, output_shape, None)

input = [torch.randn([1, 1] + list(input_shape))]
self.run_test(Upsample(), input)

# test case for nearest upsample, using scale_factors, output_size is disabled here
@parameterized.expand(
[
("upsample_nearest2d.vec_scale_0", (2, 2), (2, 2)),
("upsample_nearest2d.vec_scale_1", (2, 2), (1.5, 1.5)),
]
)
def test_upsample_nearest_scale_factor(self, _, input_shape, scale_factor):
class Upsample(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.upsample_nearest2d.vec(input, None, scale_factor)

input = [torch.randn([1, 1] + list(input_shape))]
self.run_test(Upsample(), input)

# test case for bilinear upsample, using output_size, scale_factors is disabled here
@parameterized.expand(
[
("upsample_bilinear2d.vec_outshape_0", (2, 2), (4, 4), True),
("upsample_bilinear2d.vec_outshape_1", (2, 2), (5, 5), True),
]
)
def test_upsample_bilinear_output_shape(
self, _, input_shape, output_shape, align_corners
):
class Upsample(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.upsample_bilinear2d.vec(
input,
output_shape,
align_corners,
None,
)

input = [torch.randn([1, 1] + list(input_shape))]
self.run_test(Upsample(), input)

# test case for bilinear upsample, using scale_factors, output_shape is disabled here
@parameterized.expand(
[
("upsample_bilinear2d.vec_scale_0", (2, 2), (2, 2), True),
("upsample_bilinear2d.vec_scale_1", (2, 2), (1.5, 1.5), True),
]
)
def test_upsample_bilinear_scale_factors(
self, _, input_shape, scale_factors, align_corners
):
class Upsample(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.ops.aten.upsample_bilinear2d.vec(
input,
None,
align_corners,
scale_factors,
)

input = [torch.randn([1, 1] + list(input_shape))]
self.run_test(Upsample(), input)


if __name__ == "__main__":
run_tests()