From 960458f62682c3fed18bcae355f9bcab93157c50 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 8 Dec 2023 17:49:34 -0800 Subject: [PATCH] feat: expose IResizeLayer in dynamo --- .../dynamo/conversion/aten_ops_converters.py | 42 ++++++++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/upsample.py | 67 +++++++++++++ tests/py/dynamo/conversion/test_upsample.py | 97 +++++++++++++++++++ 4 files changed, 207 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/upsample.py create mode 100644 tests/py/dynamo/conversion/test_upsample.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8eb07c07a7..4d9547d3ed 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2463,3 +2463,45 @@ def aten_ops_pad( mode=args_bounds_check(args, 2, "constant"), value=args_bounds_check(args, 3, None), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec) +def upsample_nearest2d( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.upsample.upsample( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + out_shape=args_bounds_check(args, 1), + scale_factors=args_bounds_check(args, 2), + resize_mode="nearest", + align_corners=False, + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec) +def upsample_bilinear2d( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.upsample.upsample( + ctx, + target, + SourceIR.ATEN, + name, + input=args[0], + out_shape=args_bounds_check(args, 1), + scale_factors=args_bounds_check(args, 3), + resize_mode="bilinear", + align_corners=args_bounds_check(args, 2), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 5bace705cb..ca71cb0b0c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -28,4 +28,5 @@ topk, unary, unsqueeze, + upsample, ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py new file mode 100644 index 0000000000..3313730ec3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py @@ -0,0 +1,67 @@ +from typing import Optional, Sequence + +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def upsample( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + out_shape: Optional[Sequence[int]], + scale_factors: Optional[Sequence[float]], + resize_mode: str, + align_corners: bool, +) -> TRTTensor: + resize_layer = ctx.net.add_resize(input) + # output size calculation + # Pytorch assumes that one of out_shape/scale_factor is None + # Pytorch assumes that dimensions match for out_shape/scale factor + if out_shape is not None: + resize_layer.shape = list(input.shape)[:2] + list(out_shape) + elif scale_factors is not None: + resize_layer.scales = [1.0, 1.0] + list(scale_factors) + else: + raise RuntimeError( + f"At least one of out_shape and scale_factors should be specified." + ) + + # interpolate mode + if resize_mode == "nearest" or None: + resize_layer.resize_mode = trt.ResizeMode.NEAREST + elif resize_mode == "bilinear": + resize_layer.resize_mode = trt.ResizeMode.LINEAR + if align_corners is None or not align_corners: + raise RuntimeError( + f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT." + ) + else: + raise RuntimeError( + f"Interpolation mode is {resize_mode} which is not supported by TensorRT." + ) + + if resize_mode == "nearest": + resize_layer.coordinate_transformation = ( + trt.ResizeCoordinateTransformation.ASYMMETRIC + ) + elif resize_mode == "bilinear": + # align corners + if align_corners is not None and align_corners: + resize_layer.coordinate_transformation = ( + trt.ResizeCoordinateTransformation.ALIGN_CORNERS + ) + else: + resize_layer.coordinate_transformation = ( + trt.ResizeCoordinateTransformation.ASYMMETRIC + ) + + set_layer_name(resize_layer, target, name, source_ir) + + out = resize_layer.get_output(0) + return out diff --git a/tests/py/dynamo/conversion/test_upsample.py b/tests/py/dynamo/conversion/test_upsample.py new file mode 100644 index 0000000000..448b3afb84 --- /dev/null +++ b/tests/py/dynamo/conversion/test_upsample.py @@ -0,0 +1,97 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestUpsampleConverter(DispatchTestCase): + # test case for nearest upsample, using output_size, scale_factors is disabled here + @parameterized.expand( + [ + ("upsample_nearest2d.vec_outshape_0", (2, 2), (4, 4)), + ("upsample_nearest2d.vec_outshape_1", (2, 2), (5, 5)), + ] + ) + def test_upsample_nearest_output_shape(self, _, input_shape, output_shape): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec(input, output_shape, None) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + # test case for nearest upsample, using scale_factors, output_size is disabled here + @parameterized.expand( + [ + ("upsample_nearest2d.vec_scale_0", (2, 2), (2, 2)), + ("upsample_nearest2d.vec_scale_1", (2, 2), (1.5, 1.5)), + ] + ) + def test_upsample_nearest_scale_factor(self, _, input_shape, scale_factor): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec(input, None, scale_factor) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + # test case for bilinear upsample, using output_size, scale_factors is disabled here + @parameterized.expand( + [ + ("upsample_bilinear2d.vec_outshape_0", (2, 2), (4, 4), True), + ("upsample_bilinear2d.vec_outshape_1", (2, 2), (5, 5), True), + ] + ) + def test_upsample_bilinear_output_shape( + self, _, input_shape, output_shape, align_corners + ): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.upsample_bilinear2d.vec( + input, + output_shape, + align_corners, + None, + ) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + # test case for bilinear upsample, using scale_factors, output_shape is disabled here + @parameterized.expand( + [ + ("upsample_bilinear2d.vec_scale_0", (2, 2), (2, 2), True), + ("upsample_bilinear2d.vec_scale_1", (2, 2), (1.5, 1.5), True), + ] + ) + def test_upsample_bilinear_scale_factors( + self, _, input_shape, scale_factors, align_corners + ): + class Upsample(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.aten.upsample_bilinear2d.vec( + input, + None, + align_corners, + scale_factors, + ) + + input = [torch.randn([1, 1] + list(input_shape))] + self.run_test(Upsample(), input) + + +if __name__ == "__main__": + run_tests()