From 94c7cb4cb54144ec10486ea53fc2df08363d1434 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 21 Nov 2023 17:52:03 -0800 Subject: [PATCH 01/11] feat: support constant padding dynamo converter --- .../dynamo/conversion/aten_ops_converters.py | 24 ++++++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/cat.py | 9 +-- .../dynamo/conversion/impl/pad.py | 76 +++++++++++++++++++ tests/py/dynamo/conversion/test_pad_aten.py | 35 +++++++++ 5 files changed, 140 insertions(+), 5 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/pad.py create mode 100644 tests/py/dynamo/conversion/test_pad_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index e578ec2fc7..2000835cf5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2054,3 +2054,27 @@ def aten_ops_addmm( beta=kwargs.get("beta", 1), alpha=kwargs.get("alpha", 1), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.constant_pad_nd.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_constant_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.constant_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args_bounds_check(args, 2, 0), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index b448b40bc3..62bf556beb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -16,6 +16,7 @@ linear, matmul, normalization, + pad, permutation, pool, reduce, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py index 24149d01b0..f954a10fd7 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Sequence, Union +from typing import Optional, Sequence, Union import numpy as np import torch @@ -6,12 +6,11 @@ from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion.converter_utils import ( - SourceIR, get_positive_dim, get_trt_tensor, ) from torch_tensorrt.fx.converters.converter_utils import set_layer_name -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.types import TRTTensor def cat( @@ -23,9 +22,9 @@ def cat( dim: int, ) -> Union[TRTTensor, Sequence[TRTTensor]]: trt_inputs = [] - for each_input in input: + for i, each_input in enumerate(input): if not isinstance(each_input, TRTTensor): - each_input = get_trt_tensor(ctx, each_input, name + "_tensor_{i}") + each_input = get_trt_tensor(ctx, each_input, name + f"_tensor_{i}") trt_inputs.append(each_input) concat_layer = ctx.net.add_concatenation(trt_inputs) dim = get_positive_dim(dim, len(input[0].shape)) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py new file mode 100644 index 0000000000..7c7accecd4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -0,0 +1,76 @@ +import copy +from typing import Optional, Sequence, Union + +import torch +import torch_tensorrt.dynamo.conversion.impl as 
impl +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.fx.converters.converter_utils import ( + has_dynamic_shape, + set_layer_name, +) +from torch_tensorrt.fx.types import TRTTensor + + +def constant_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + pad: Sequence[int], + value: int = 0, +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." + + pad_len = len(pad) + + if pad_len == 4 and value == 0: + pre_padding = (pad[2], pad[0]) + post_padding = (pad[3], pad[1]) + + # add padding layer + pad_layer = ctx.net.add_padding_nd( + input=input, + pre_padding=pre_padding, + post_padding=post_padding, + ) + + pad_layer.pre_padding_nd = pre_padding + pad_layer.post_padding_nd = post_padding + + set_layer_name(pad_layer, target, name, source_ir) + return pad_layer.get_output(0) + + else: + # Implement constant padding via concat + curr_dim = len(input.shape) - 1 + + for i in range(0, pad_len, 2): + input_shape = list(input.shape) + + pre_pad = pad[i] + post_pad = pad[i + 1] + pre_pad_shape = copy.deepcopy(input_shape) + pre_pad_shape[curr_dim] = pre_pad + pre_pad_tensor = torch.full(pre_pad_shape, float(value)) + if pre_pad == post_pad: + post_pad_tensor = pre_pad_tensor + else: + post_pad_shape = copy.deepcopy(input_shape) + post_pad_shape[curr_dim] = post_pad + post_pad_tensor = torch.full(post_pad_shape, float(value)) + output = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat{curr_dim}", + input=(pre_pad_tensor, input, post_pad_tensor), + dim=curr_dim, + ) + curr_dim -= 1 + input = output + + return output diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py new file mode 100644 index 0000000000..993d276424 --- /dev/null +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -0,0 +1,35 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestConstantPadConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 2), (1, 1), 0), + ((2, 1), (2, 1), 1), + ((3, 4, 2), (1, 2), 2), + ((3, 4, 2), (1, 2, 3, 1, 2, 3), 0), + ((3, 3, 4, 2), (1, 2, 3, 4), 0), + ((3, 3, 4, 2), (1, 2, 3, 4), 2), + ((3, 3, 4, 2, 1), (1, 2, 3, 4, 5, 1, 2, 3, 4, 5), 0), + ((3, 3, 4, 2, 1, 2), (1, 2, 3, 4, 1, 2, 3, 4), 4), + ] + ) + def test_constant_pad(self, shape, pad, value): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.constant_pad_nd.default(input, pad, value) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + +if __name__ == "__main__": + run_tests() From c63cbafd7e89514aeccce773b4c9192c1c3f76ab Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 22 Nov 2023 14:31:43 -0800 Subject: [PATCH 02/11] feat: support reflection padding dynamo converters for 1D, 2D, and 3D --- .../dynamo/conversion/aten_ops_converters.py | 25 ++++++++ .../dynamo/conversion/impl/pad.py | 59 +++++++++++++++++ .../dynamo/conversion/impl/slice/ops.py | 4 -- tests/py/dynamo/conversion/test_pad_aten.py | 63 ++++++++++++++++++- 4 files changed, 146 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py 
b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 2000835cf5..97c87a67df 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2078,3 +2078,28 @@ def aten_ops_constant_pad( args[1], args_bounds_check(args, 2, 0), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad1d.default) +@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad2d.default) +@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad3d.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_reflection_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.reflection_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index 7c7accecd4..6ac2dc0913 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -74,3 +74,62 @@ def constant_padNd( input = output return output + + +def reflection_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + padding: Sequence[int], +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." + + padding_dims = len(padding) // 2 + + if padding_dims == 1 or padding_dims == 2 or padding_dims == 3: + for i in range(padding_dims): + dim = -1 - i + pre_pad, post_pad = padding[2 * i], padding[2 * i + 1] + pre_pad_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_pre{i}", + input, + dim=dim, + start=pre_pad, + stop=0, + step=-1, + ) + + post_pad_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_post{i}", + input, + dim=dim, + start=input.shape[dim] - 2, + stop=input.shape[dim] - post_pad - 2, + step=-1, + ) + + output = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat_dim{dim}", + input=(pre_pad_tensor, input, post_pad_tensor), + dim=dim, + ) + input = output + + return output + + else: + raise RuntimeError( + f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 91ac4a7042..e8d329fb87 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -39,10 +39,6 @@ def slice_op( # TODO: This should be slice not whatever is in base if stop is None: stop = input.shape[dim] - dim = get_positive_dim(dim, len(input.shape)) - start = get_positive_dim(start, input.shape[dim]) - stop = get_positive_dim(stop, input.shape[dim]) - if has_dynamic_shape(input.shape): # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!" 
diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py index 993d276424..abfd8a8e77 100644 --- a/tests/py/dynamo/conversion/test_pad_aten.py +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -1,7 +1,6 @@ import torch from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input from .harness import DispatchTestCase @@ -31,5 +30,67 @@ def forward(self, input): ) +class TestReflectionPadConverter(DispatchTestCase): + @parameterized.expand( + [ + # Per pytorch doc, the input should be 2D or 3D + ((3, 3), (1, 1)), + ((3, 3), (2, 2)), + ((2, 2, 2), (1, 1)), + ((2, 2, 4), (2, 3)), + ] + ) + def test_reflection_pad1d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.reflection_pad1d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 3D or 4D + ((2, 2, 2), (1, 1, 1, 1)), + ((1, 2, 4), (2, 2, 1, 1)), + ((2, 2, 3, 3), (1, 1, 2, 2)), + ((2, 3, 4, 5), (4, 3, 0, 1)), + ] + ) + def test_reflection_pad2d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.reflection_pad2d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 4D or 5D + ((2, 2, 2, 2), (1, 1, 1, 1, 1, 1)), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1)), + ((2, 2, 3, 3, 4), (3, 3, 2, 1, 1, 2)), + ((2, 3, 4, 5, 6), (4, 3, 2, 1, 1, 0)), + ] + ) + def test_reflection_pad3d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.reflection_pad3d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + if __name__ == "__main__": run_tests() From 215df670bef784638c00bc35a0a4830d7496bdec Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 22 Nov 2023 15:31:30 -0800 Subject: [PATCH 03/11] feat: support replication padding dynamo converters for 1D, 2D, and 3D --- .../dynamo/conversion/aten_ops_converters.py | 25 ++++++ .../dynamo/conversion/impl/pad.py | 77 +++++++++++++++++++ tests/py/dynamo/conversion/test_pad_aten.py | 62 +++++++++++++++ 3 files changed, 164 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 97c87a67df..5ced333e87 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2103,3 +2103,28 @@ def aten_ops_reflection_pad( args[0], args[1], ) + + +@dynamo_tensorrt_converter(torch.ops.aten.replication_pad1d.default) +@dynamo_tensorrt_converter(torch.ops.aten.replication_pad2d.default) +@dynamo_tensorrt_converter(torch.ops.aten.replication_pad3d.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_replication_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.replication_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index 6ac2dc0913..47804a25ac 100644 --- 
a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -133,3 +133,80 @@ def reflection_padNd( raise RuntimeError( f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" ) + + +def replication_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + padding: Sequence[int], +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." + + padding_dims = len(padding) // 2 + + if padding_dims == 1 or padding_dims == 2 or padding_dims == 3: + for i in range(padding_dims): + dim = -1 - i + pre_pad, post_pad = padding[2 * i], padding[2 * i + 1] + pre_pad_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_pre{i}", + input, + dim=dim, + start=0, + stop=1, + step=1, + ) + new_shape = input.shape + new_shape[dim] = pre_pad + pre_pad_tensor = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_pre{i}", + pre_pad_tensor, + new_shape, + ) + + post_pad_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_post{i}", + input, + dim=dim, + start=input.shape[dim] - 1, + stop=input.shape[dim], + step=1, + ) + new_shape[dim] = post_pad + post_pad_tensor = impl.slice.expand( + ctx, + target, + source_ir, + f"{name}_expand_post{i}", + post_pad_tensor, + new_shape, + ) + output = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat_dim{dim}", + input=(pre_pad_tensor, input, post_pad_tensor), + dim=dim, + ) + input = output + + return output + + else: + raise RuntimeError( + f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" + ) diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py index abfd8a8e77..8dd6ae2a20 100644 --- a/tests/py/dynamo/conversion/test_pad_aten.py +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -92,5 +92,67 @@ def forward(self, input): ) +class TestReplicationPadConverter(DispatchTestCase): + @parameterized.expand( + [ + # Per pytorch doc, the input should be 2D or 3D + ((3, 3), (1, 1)), + ((3, 3), (2, 2)), + ((2, 2, 2), (1, 1)), + ((2, 2, 4), (2, 3)), + ] + ) + def test_replication_pad1d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.replication_pad1d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 3D or 4D + ((2, 2, 2), (1, 1, 1, 1)), + ((1, 2, 4), (2, 2, 1, 1)), + ((2, 2, 3, 3), (1, 1, 2, 2)), + ((2, 3, 4, 5), (4, 3, 0, 1)), + ] + ) + def test_replication_pad2d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.replication_pad2d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 4D or 5D + ((2, 2, 2, 2), (1, 1, 1, 1, 1, 1)), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1)), + ((2, 2, 3, 3, 4), (3, 3, 2, 1, 1, 2)), + ((2, 3, 4, 5, 6), (4, 3, 2, 1, 1, 0)), + ] + ) + def test_replication_pad3d(self, shape, padding): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.replication_pad3d.default(input, padding) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + if __name__ == "__main__": run_tests() 
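Patches 01-03 above all build the padded result by concatenating extra border tensors onto the input, one padded dimension at a time: constant padding concatenates tensors filled with the pad value, reflection concatenates reversed border slices, and replication concatenates edge slices expanded to the pad width. A minimal PyTorch sketch of the replication case just tested, shown purely as an illustration of the slice/expand/cat idea rather than as part of the patch series (the helper name here is invented for the example):

    import torch
    import torch.nn.functional as F

    def replication_pad1d_via_cat(x: torch.Tensor, pre: int, post: int) -> torch.Tensor:
        # Take the single edge element on each side, stretch it to the pad width,
        # and concatenate along the last dimension -- the same slice/expand/cat
        # sequence that replication_padNd expresses with TensorRT layers.
        left = x[..., :1].expand(*x.shape[:-1], pre)
        right = x[..., -1:].expand(*x.shape[:-1], post)
        return torch.cat((left, x, right), dim=-1)

    x = torch.randn(2, 3, 5)
    assert torch.equal(
        replication_pad1d_via_cat(x, 2, 3),
        F.pad(x, (2, 3), mode="replicate"),
    )

Patch 08 below later collapses these per-dimension slice/concat chains into a single TensorRT ISliceLayer, selecting the FILL, REFLECT, CLAMP, or WRAP slice mode for constant, reflection, replication, and circular padding respectively.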
From 7e0e4776403c60c773f5877c7191d36297a2164f Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 22 Nov 2023 16:44:49 -0800 Subject: [PATCH 04/11] feat: support circular padding dynamo converters for 1D, 2D, and 3D --- .../dynamo/conversion/aten_ops_converters.py | 23 +++++++ .../dynamo/conversion/impl/pad.py | 59 ++++++++++++++++++ tests/py/dynamo/conversion/test_pad_aten.py | 62 +++++++++++++++++++ 3 files changed, 144 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 5ced333e87..73d8c175e9 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2128,3 +2128,26 @@ def aten_ops_replication_pad( args[0], args[1], ) + + +@dynamo_tensorrt_converter(torch.ops.aten._pad_circular.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_circular_pad( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.circular_padNd( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index 47804a25ac..df874f06c6 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -210,3 +210,62 @@ def replication_padNd( raise RuntimeError( f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" ) + + +def circular_padNd( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + pad: Sequence[int], +) -> TRTTensor: + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." 
+ + padding_dims = len(pad) // 2 + + if padding_dims == 1 or padding_dims == 2 or padding_dims == 3: + for i in range(padding_dims): + dim = -1 - i + pre_pad, post_pad = pad[2 * i], pad[2 * i + 1] + pre_pad_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_pre{i}", + input, + dim=dim, + start=input.shape[dim] - pre_pad, + stop=input.shape[dim], + step=1, + ) + + post_pad_tensor = impl.slice.slice_op( + ctx, + target, + source_ir, + f"{name}_slice_post{i}", + input, + dim=dim, + start=0, + stop=post_pad, + step=1, + ) + + output = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat_dim{dim}", + input=(pre_pad_tensor, input, post_pad_tensor), + dim=dim, + ) + input = output + + return output + + else: + raise RuntimeError( + f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" + ) diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py index 8dd6ae2a20..4d25a10ef3 100644 --- a/tests/py/dynamo/conversion/test_pad_aten.py +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -154,5 +154,67 @@ def forward(self, input): ) +class TestCircularPadConverter(DispatchTestCase): + @parameterized.expand( + [ + # Per pytorch doc, the input should be 2D or 3D + ((3, 3), (1, 1)), + ((3, 3), (2, 2)), + ((2, 2, 2), (1, 1)), + ((2, 2, 4), (2, 3)), + ] + ) + def test_circular_pad1d(self, shape, pad): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten._pad_circular.default(input, pad) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 3D or 4D + ((2, 2, 2), (1, 1, 1, 1)), + ((1, 2, 4), (2, 2, 1, 1)), + ((2, 2, 3, 3), (1, 1, 2, 2)), + ((2, 3, 4, 5), (4, 3, 0, 1)), + ] + ) + def test_circular_pad2d(self, shape, pad): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten._pad_circular.default(input, pad) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + @parameterized.expand( + [ + # Per pytorch doc, the input should be 4D or 5D + ((2, 2, 2, 2), (1, 1, 1, 1, 1, 1)), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1)), + ((2, 2, 3, 3, 4), (3, 3, 2, 1, 1, 2)), + ((2, 3, 4, 5, 6), (4, 3, 2, 1, 1, 0)), + ] + ) + def test_circular_pad3d(self, shape, pad): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten._pad_circular.default(input, pad) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + if __name__ == "__main__": run_tests() From 78c2f43e43bf40668b1a3265a31006d62858cb96 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 22 Nov 2023 16:50:55 -0800 Subject: [PATCH 05/11] feat: support pad dynamo converter --- .../dynamo/conversion/aten_ops_converters.py | 25 ++++++++++++++ .../dynamo/conversion/impl/pad.py | 34 ++++++++++++++++++- tests/py/dynamo/conversion/test_pad_aten.py | 21 ++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 73d8c175e9..fb031c4f19 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2151,3 +2151,28 @@ def aten_ops_circular_pad( args[0], args[1], ) + + +@dynamo_tensorrt_converter(torch.ops.aten.pad.default) +@enforce_tensor_types( + { + 0: (TRTTensor,), + } +) +def aten_ops_pad( + ctx: 
ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.pad.pad( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + pad=args[1], + mode=args_bounds_check(args, 2, "constant"), + value=args_bounds_check(args, 3, None), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index df874f06c6..7c0741d483 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -20,7 +20,7 @@ def constant_padNd( name: str, input: TRTTensor, pad: Sequence[int], - value: int = 0, + value: Union[int, float] = 0, ) -> TRTTensor: if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." @@ -269,3 +269,35 @@ def circular_padNd( raise RuntimeError( f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" ) + + +def pad( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + pad: Sequence[int], + mode: str = "constant", + value: Optional[float] = None, +) -> TRTTensor: + if mode == "constant": + return constant_padNd( + ctx, + target, + source_ir, + f"{name}_{mode}", + input, + pad, + value if value is not None else 0, + ) + elif mode == "reflect": + return reflection_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad) + elif mode == "replicate": + return replication_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad) + elif mode == "circular": + return circular_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad) + else: + raise RuntimeError( + f'We currently only support for `mode` in ["constant", "reflect", "replicate", "circular"], but got {mode}' + ) diff --git a/tests/py/dynamo/conversion/test_pad_aten.py b/tests/py/dynamo/conversion/test_pad_aten.py index 4d25a10ef3..2803736ad0 100644 --- a/tests/py/dynamo/conversion/test_pad_aten.py +++ b/tests/py/dynamo/conversion/test_pad_aten.py @@ -216,5 +216,26 @@ def forward(self, input): ) +class TestPadConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3, 3), (2, 2), "constant"), + ((2, 2, 4), (2, 3, 1, 0), "reflect"), + ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1), "replicate"), + ((2, 3, 4, 5), (3, 2, 1, 0), "circular"), + ] + ) + def test_pad(self, shape, pad, mode, value=None): + class TestModule(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.pad.default(input, pad, mode, value) + + input = [torch.randn(shape)] + self.run_test( + TestModule(), + input, + ) + + if __name__ == "__main__": run_tests() From ccc5d3c42b0f851fef1981da4554da2e562094c5 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 22 Nov 2023 17:12:39 -0800 Subject: [PATCH 06/11] fix a concat bug --- py/torch_tensorrt/dynamo/conversion/impl/cat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cat.py b/py/torch_tensorrt/dynamo/conversion/impl/cat.py index f954a10fd7..d6ffc77377 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cat.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cat.py @@ -24,10 +24,10 @@ def cat( trt_inputs = [] for i, each_input in enumerate(input): if not isinstance(each_input, TRTTensor): - each_input = get_trt_tensor(ctx, each_input, name + f"_tensor_{i}") + each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}") trt_inputs.append(each_input) concat_layer = 
ctx.net.add_concatenation(trt_inputs) dim = get_positive_dim(dim, len(input[0].shape)) concat_layer.axis = dim - set_layer_name(concat_layer, target, name + "_gather", source_ir) + set_layer_name(concat_layer, target, f"{name}_gather", source_ir) return concat_layer.get_output(0) From d908601761616ce404266d2a023052ddb7ba70de Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 22 Nov 2023 17:34:22 -0800 Subject: [PATCH 07/11] update constant pad --- .../dynamo/conversion/impl/pad.py | 84 ++++++++----------- 1 file changed, 33 insertions(+), 51 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index 7c0741d483..21097af39d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -6,10 +6,7 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.fx.converters.converter_utils import ( - has_dynamic_shape, - set_layer_name, -) +from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape from torch_tensorrt.fx.types import TRTTensor @@ -22,58 +19,43 @@ def constant_padNd( pad: Sequence[int], value: Union[int, float] = 0, ) -> TRTTensor: + """ + Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0. + Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding + mode and clamp, and supports padding output with dynamic shape. + """ if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." - pad_len = len(pad) - - if pad_len == 4 and value == 0: - pre_padding = (pad[2], pad[0]) - post_padding = (pad[3], pad[1]) - - # add padding layer - pad_layer = ctx.net.add_padding_nd( - input=input, - pre_padding=pre_padding, - post_padding=post_padding, + # Implement constant padding via concat + curr_dim = len(input.shape) - 1 + + for i in range(0, len(pad), 2): + input_shape = list(input.shape) + + pre_pad = pad[i] + post_pad = pad[i + 1] + pre_pad_shape = copy.deepcopy(input_shape) + pre_pad_shape[curr_dim] = pre_pad + pre_pad_tensor = torch.full(pre_pad_shape, float(value)) + if pre_pad == post_pad: + post_pad_tensor = pre_pad_tensor + else: + post_pad_shape = copy.deepcopy(input_shape) + post_pad_shape[curr_dim] = post_pad + post_pad_tensor = torch.full(post_pad_shape, float(value)) + output = impl.cat.cat( + ctx, + target, + source_ir, + f"{name}_concat{curr_dim}", + input=(pre_pad_tensor, input, post_pad_tensor), + dim=curr_dim, ) + curr_dim -= 1 + input = output - pad_layer.pre_padding_nd = pre_padding - pad_layer.post_padding_nd = post_padding - - set_layer_name(pad_layer, target, name, source_ir) - return pad_layer.get_output(0) - - else: - # Implement constant padding via concat - curr_dim = len(input.shape) - 1 - - for i in range(0, pad_len, 2): - input_shape = list(input.shape) - - pre_pad = pad[i] - post_pad = pad[i + 1] - pre_pad_shape = copy.deepcopy(input_shape) - pre_pad_shape[curr_dim] = pre_pad - pre_pad_tensor = torch.full(pre_pad_shape, float(value)) - if pre_pad == post_pad: - post_pad_tensor = pre_pad_tensor - else: - post_pad_shape = copy.deepcopy(input_shape) - post_pad_shape[curr_dim] = post_pad - post_pad_tensor = torch.full(post_pad_shape, float(value)) - output = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_concat{curr_dim}", - input=(pre_pad_tensor, input, post_pad_tensor), - 
dim=curr_dim, - ) - curr_dim -= 1 - input = output - - return output + return output def reflection_padNd( From 7534fa21e2d8121f3a65fba6b1e693974d62d3a0 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 27 Nov 2023 14:43:22 -0800 Subject: [PATCH 08/11] implement paddings via TRT ISliceLayer with different SliceMode --- .../dynamo/conversion/impl/pad.py | 286 +++++++----------- .../dynamo/conversion/impl/slice/ops.py | 4 + 2 files changed, 107 insertions(+), 183 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index 21097af39d..38a14cb8df 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -1,14 +1,22 @@ -import copy from typing import Optional, Sequence, Union -import torch -import torch_tensorrt.dynamo.conversion.impl as impl +import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape +from torch_tensorrt.fx.converters.converter_utils import ( + get_trt_tensor, + has_dynamic_shape, + set_layer_name, +) from torch_tensorrt.fx.types import TRTTensor +""" +Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0. +Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding +mode and clamp, and supports padding output with dynamic shape. +""" + def constant_padNd( ctx: ConversionContext, @@ -19,43 +27,36 @@ def constant_padNd( pad: Sequence[int], value: Union[int, float] = 0, ) -> TRTTensor: - """ - Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0. - Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding - mode and clamp, and supports padding output with dynamic shape. - """ if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." - # Implement constant padding via concat - curr_dim = len(input.shape) - 1 - - for i in range(0, len(pad), 2): - input_shape = list(input.shape) - - pre_pad = pad[i] - post_pad = pad[i + 1] - pre_pad_shape = copy.deepcopy(input_shape) - pre_pad_shape[curr_dim] = pre_pad - pre_pad_tensor = torch.full(pre_pad_shape, float(value)) - if pre_pad == post_pad: - post_pad_tensor = pre_pad_tensor - else: - post_pad_shape = copy.deepcopy(input_shape) - post_pad_shape[curr_dim] = post_pad - post_pad_tensor = torch.full(post_pad_shape, float(value)) - output = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_concat{curr_dim}", - input=(pre_pad_tensor, input, post_pad_tensor), - dim=curr_dim, + rank = len(input.shape) + + if len(pad) / 2 > rank: + raise RuntimeError( + f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension." 
) - curr_dim -= 1 - input = output - return output + start_list = [0] * len(input.shape) + new_shape = input.shape + + for i in range(0, len(pad) // 2): + start_list[-i - 1] = -pad[i * 2] + new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1] + + stride_list = [1] * len(new_shape) + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + value_const = get_trt_tensor(ctx.net, value, f"{name}_value", input.dtype) + layer.set_input(4, value_const) + layer.mode = trt.SliceMode.FILL + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) def reflection_padNd( @@ -69,53 +70,32 @@ def reflection_padNd( if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." - padding_dims = len(padding) // 2 - - if padding_dims == 1 or padding_dims == 2 or padding_dims == 3: - for i in range(padding_dims): - dim = -1 - i - pre_pad, post_pad = padding[2 * i], padding[2 * i + 1] - pre_pad_tensor = impl.slice.slice_op( - ctx, - target, - source_ir, - f"{name}_slice_pre{i}", - input, - dim=dim, - start=pre_pad, - stop=0, - step=-1, - ) - - post_pad_tensor = impl.slice.slice_op( - ctx, - target, - source_ir, - f"{name}_slice_post{i}", - input, - dim=dim, - start=input.shape[dim] - 2, - stop=input.shape[dim] - post_pad - 2, - step=-1, - ) - - output = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_concat_dim{dim}", - input=(pre_pad_tensor, input, post_pad_tensor), - dim=dim, - ) - input = output - - return output + rank = len(input.shape) - else: + if len(padding) / 2 > rank: raise RuntimeError( - f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" + f"Trying to pad last {len(padding) / 2} dimension but the input only has {rank} dimension." ) + start_list = [0] * len(input.shape) + new_shape = input.shape + + for i in range(0, len(padding) // 2): + start_list[-i - 1] = -padding[i * 2] + new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1] + + stride_list = [1] * len(new_shape) + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + layer.mode = trt.SliceMode.REFLECT + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + def replication_padNd( ctx: ConversionContext, @@ -128,71 +108,32 @@ def replication_padNd( if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." 
- padding_dims = len(padding) // 2 - - if padding_dims == 1 or padding_dims == 2 or padding_dims == 3: - for i in range(padding_dims): - dim = -1 - i - pre_pad, post_pad = padding[2 * i], padding[2 * i + 1] - pre_pad_tensor = impl.slice.slice_op( - ctx, - target, - source_ir, - f"{name}_slice_pre{i}", - input, - dim=dim, - start=0, - stop=1, - step=1, - ) - new_shape = input.shape - new_shape[dim] = pre_pad - pre_pad_tensor = impl.slice.expand( - ctx, - target, - source_ir, - f"{name}_expand_pre{i}", - pre_pad_tensor, - new_shape, - ) - - post_pad_tensor = impl.slice.slice_op( - ctx, - target, - source_ir, - f"{name}_slice_post{i}", - input, - dim=dim, - start=input.shape[dim] - 1, - stop=input.shape[dim], - step=1, - ) - new_shape[dim] = post_pad - post_pad_tensor = impl.slice.expand( - ctx, - target, - source_ir, - f"{name}_expand_post{i}", - post_pad_tensor, - new_shape, - ) - output = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_concat_dim{dim}", - input=(pre_pad_tensor, input, post_pad_tensor), - dim=dim, - ) - input = output - - return output + rank = len(input.shape) - else: + if len(padding) / 2 > rank: raise RuntimeError( - f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" + f"Trying to pad last {len(padding) / 2} dimension but the input only has {rank} dimension." ) + start_list = [0] * len(input.shape) + new_shape = input.shape + + for i in range(0, len(padding) // 2): + start_list[-i - 1] = -padding[i * 2] + new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1] + + stride_list = [1] * len(new_shape) + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + layer.mode = trt.SliceMode.CLAMP + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + def circular_padNd( ctx: ConversionContext, @@ -205,53 +146,32 @@ def circular_padNd( if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for padding." - padding_dims = len(pad) // 2 - - if padding_dims == 1 or padding_dims == 2 or padding_dims == 3: - for i in range(padding_dims): - dim = -1 - i - pre_pad, post_pad = pad[2 * i], pad[2 * i + 1] - pre_pad_tensor = impl.slice.slice_op( - ctx, - target, - source_ir, - f"{name}_slice_pre{i}", - input, - dim=dim, - start=input.shape[dim] - pre_pad, - stop=input.shape[dim], - step=1, - ) - - post_pad_tensor = impl.slice.slice_op( - ctx, - target, - source_ir, - f"{name}_slice_post{i}", - input, - dim=dim, - start=0, - stop=post_pad, - step=1, - ) - - output = impl.cat.cat( - ctx, - target, - source_ir, - f"{name}_concat_dim{dim}", - input=(pre_pad_tensor, input, post_pad_tensor), - dim=dim, - ) - input = output - - return output + rank = len(input.shape) - else: + if len(pad) / 2 > rank: raise RuntimeError( - f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D" + f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension." 
) + start_list = [0] * len(input.shape) + new_shape = input.shape + + for i in range(0, len(pad) // 2): + start_list[-i - 1] = -pad[i * 2] + new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1] + + stride_list = [1] * len(new_shape) + layer = ctx.net.add_slice( + input, + start=tuple(start_list), + shape=tuple(new_shape), + stride=tuple(stride_list), + ) + layer.mode = trt.SliceMode.WRAP + + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + def pad( ctx: ConversionContext, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index e8d329fb87..91ac4a7042 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -39,6 +39,10 @@ def slice_op( # TODO: This should be slice not whatever is in base if stop is None: stop = input.shape[dim] + dim = get_positive_dim(dim, len(input.shape)) + start = get_positive_dim(start, input.shape[dim]) + stop = get_positive_dim(stop, input.shape[dim]) + if has_dynamic_shape(input.shape): # Check whether slice target dim is dynamic shape dim assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!" From 1b21c9bf2c30758dc91b3ab01127a108c71643d0 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 27 Nov 2023 17:51:28 -0800 Subject: [PATCH 09/11] fix import bug --- py/torch_tensorrt/dynamo/conversion/impl/pad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index 38a14cb8df..fca4133387 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -4,8 +4,8 @@ from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor from torch_tensorrt.fx.converters.converter_utils import ( - get_trt_tensor, has_dynamic_shape, set_layer_name, ) From 2e8c094a6cb0d5f1605ee303f72779bf76a7a630 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 28 Nov 2023 15:16:57 -0800 Subject: [PATCH 10/11] fix bugs --- .../dynamo/conversion/impl/pad.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index fca4133387..c9dd58359a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -32,9 +32,9 @@ def constant_padNd( rank = len(input.shape) - if len(pad) / 2 > rank: + if len(pad) // 2 > rank: raise RuntimeError( - f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension." + f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension." ) start_list = [0] * len(input.shape) @@ -51,7 +51,7 @@ def constant_padNd( shape=tuple(new_shape), stride=tuple(stride_list), ) - value_const = get_trt_tensor(ctx.net, value, f"{name}_value", input.dtype) + value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype) layer.set_input(4, value_const) layer.mode = trt.SliceMode.FILL @@ -72,9 +72,9 @@ def reflection_padNd( rank = len(input.shape) - if len(padding) / 2 > rank: + if len(padding) // 2 > rank: raise RuntimeError( - f"Trying to pad last {len(padding) / 2} dimension but the input only has {rank} dimension." 
+ f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension." ) start_list = [0] * len(input.shape) @@ -110,9 +110,9 @@ def replication_padNd( rank = len(input.shape) - if len(padding) / 2 > rank: + if len(padding) // 2 > rank: raise RuntimeError( - f"Trying to pad last {len(padding) / 2} dimension but the input only has {rank} dimension." + f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension." ) start_list = [0] * len(input.shape) @@ -148,9 +148,9 @@ def circular_padNd( rank = len(input.shape) - if len(pad) / 2 > rank: + if len(pad) // 2 > rank: raise RuntimeError( - f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension." + f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension." ) start_list = [0] * len(input.shape) From a682814a05b42ed2763d5d98ecd88d75d9af287c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 29 Nov 2023 14:18:29 -0800 Subject: [PATCH 11/11] add some small modifications --- .../dynamo/conversion/impl/pad.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pad.py b/py/torch_tensorrt/dynamo/conversion/impl/pad.py index c9dd58359a..3764667ffb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pad.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pad.py @@ -37,14 +37,14 @@ def constant_padNd( f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension." ) - start_list = [0] * len(input.shape) - new_shape = input.shape + start_list = [0] * rank + new_shape = list(input.shape) for i in range(0, len(pad) // 2): start_list[-i - 1] = -pad[i * 2] new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1] - stride_list = [1] * len(new_shape) + stride_list = [1] * rank layer = ctx.net.add_slice( input, start=tuple(start_list), @@ -77,14 +77,14 @@ def reflection_padNd( f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension." ) - start_list = [0] * len(input.shape) - new_shape = input.shape + start_list = [0] * rank + new_shape = list(input.shape) for i in range(0, len(padding) // 2): start_list[-i - 1] = -padding[i * 2] new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1] - stride_list = [1] * len(new_shape) + stride_list = [1] * rank layer = ctx.net.add_slice( input, start=tuple(start_list), @@ -115,14 +115,14 @@ def replication_padNd( f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension." ) - start_list = [0] * len(input.shape) - new_shape = input.shape + start_list = [0] * rank + new_shape = list(input.shape) for i in range(0, len(padding) // 2): start_list[-i - 1] = -padding[i * 2] new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1] - stride_list = [1] * len(new_shape) + stride_list = [1] * rank layer = ctx.net.add_slice( input, start=tuple(start_list), @@ -153,14 +153,14 @@ def circular_padNd( f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension." ) - start_list = [0] * len(input.shape) - new_shape = input.shape + start_list = [0] * rank + new_shape = list(input.shape) for i in range(0, len(pad) // 2): start_list[-i - 1] = -pad[i * 2] new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1] - stride_list = [1] * len(new_shape) + stride_list = [1] * rank layer = ctx.net.add_slice( input, start=tuple(start_list),