diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 663f7436db..3372af067c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1085,7 +1085,7 @@ def aten_ops_expand(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
+@dynamo_tensorrt_converter(torch.ops.aten.amax.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -1109,7 +1109,7 @@ def aten_ops_amax(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.amin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.amin.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -1133,9 +1133,9 @@ def aten_ops_amin(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
-@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
-@dynamo_tensorrt_converter(torch.ops.prims.sum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.prims.sum.default, supports_dynamic_shapes=True)
 def aten_ops_sum(
     ctx: ConversionContext,
     target: Target,
@@ -1167,8 +1167,8 @@ def aten_ops_sum(
     return sum_
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.prod.default)
-@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int)
+@dynamo_tensorrt_converter(torch.ops.aten.prod.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int, supports_dynamic_shapes=True)
 def aten_ops_prod(
     ctx: ConversionContext,
     target: Target,
@@ -1187,9 +1187,14 @@ def aten_ops_prod(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.max.default)
 @dynamo_tensorrt_converter(
-    torch.ops.aten.max.dim, capability_validator=one_user_validator
+    torch.ops.aten.max.default,
+    supports_dynamic_shapes=True,
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.max.dim,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
 )
 def aten_ops_max(
     ctx: ConversionContext,
@@ -1210,9 +1215,14 @@ def aten_ops_max(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.min.default)
 @dynamo_tensorrt_converter(
-    torch.ops.aten.min.dim, capability_validator=one_user_validator
+    torch.ops.aten.min.default,
+    supports_dynamic_shapes=True,
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.min.dim,
+    capability_validator=one_user_validator,
+    supports_dynamic_shapes=True,
 )
 def aten_ops_min(
     ctx: ConversionContext,
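[Note, not part of the patch: a minimal sketch of what `supports_dynamic_shapes=True` unlocks for these reductions at the user level. As I read the converter registry, ops whose converters lack the flag are left to run in PyTorch when the graph has dynamic input shapes, so this change lets the reductions stay inside the TensorRT engine. The module name and shape ranges below are illustrative, not taken from this patch, and a CUDA device plus a working torch_tensorrt install are assumed.]

```python
import torch
import torch_tensorrt


class AmaxModule(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.amax.default, which this patch marks as
        # dynamic-shape capable.
        return torch.amax(x, dim=1, keepdim=True)


model = AmaxModule().eval().cuda()
# A min/opt/max range replaces a fixed input shape (ranges are illustrative).
inputs = [
    torch_tensorrt.Input(
        min_shape=(2, 3, 5),
        opt_shape=(3, 4, 6),
        max_shape=(4, 5, 7),
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
# One compiled engine now serves every shape inside the declared range.
print(trt_model(torch.randn(2, 3, 5).cuda()).shape)  # torch.Size([2, 1, 5])
print(trt_model(torch.randn(4, 5, 7).cuda()).shape)  # torch.Size([4, 1, 7])
```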
diff --git a/tests/py/dynamo/conversion/test_amax_aten.py b/tests/py/dynamo/conversion/test_amax_aten.py
index bdb0db5779..6da06d953a 100644
--- a/tests/py/dynamo/conversion/test_amax_aten.py
+++ b/tests/py/dynamo/conversion/test_amax_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -90,6 +91,38 @@ def forward(self, x):
             check_dtype=False,
         )
 
+    @parameterized.expand(
+        [
+            ((0, 1), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((0,), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (2, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((-1, 0), True, (2, 2, 5), (3, 3, 6), (4, 5, 7)),
+        ]
+    )
+    def test_amax_dynamic_shape(self, dim, keep_dim, min_shape, opt_shape, max_shape):
+        class Amax(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.amax.default(x, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Amax(dim),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_amin_aten.py b/tests/py/dynamo/conversion/test_amin_aten.py
index 03ae9b6113..4ab68a7466 100644
--- a/tests/py/dynamo/conversion/test_amin_aten.py
+++ b/tests/py/dynamo/conversion/test_amin_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -90,6 +91,38 @@ def forward(self, x):
             check_dtype=False,
         )
 
+    @parameterized.expand(
+        [
+            ((0, 1), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((0,), False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((-1, 0), True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_amin_dynamic_shape(self, dim, keep_dim, min_shape, opt_shape, max_shape):
+        class Amin(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Amin(dim),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
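[Note, not part of the patch: a quick eager-mode reference for the parameterizations above. `aten.amax`/`aten.amin` take a single dim or a sequence of dims, and `keep_dim` decides whether the reduced axes survive. Worth noting: the `Amax`/`Amin` modules store `self.dim` in `__init__` but their `forward` closes over the test arguments `dim`/`keep_dim` directly; both resolve to the same values here, so behavior is unaffected. Plain PyTorch, no TensorRT required:]

```python
import torch

x = torch.randn(2, 3, 5)
# A list of dims with keepdim=True preserves rank: (1, 1, 5)
print(torch.ops.aten.amax.default(x, [0, 1], True).shape)
# A single dim with keepdim=False drops the reduced axis: (2, 5)
print(torch.ops.aten.amin.default(x, [1], False).shape)
```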
diff --git a/tests/py/dynamo/conversion/test_max_aten.py b/tests/py/dynamo/conversion/test_max_aten.py
index 7839bc4113..7652a7f8c8 100644
--- a/tests/py/dynamo/conversion/test_max_aten.py
+++ b/tests/py/dynamo/conversion/test_max_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -65,6 +66,62 @@ def forward(self, x):
             check_dtype=False,
         )
 
+    @parameterized.expand(
+        [
+            (1, True, (2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            (2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_max_dim_dynamic_shape(
+        self, dim, keep_dim, min_shape, opt_shape, max_shape
+    ):
+        class Max(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.max.dim(x, dim, keep_dim)[0]
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Max(dim),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            ((2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_max_default_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class Max(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.max.default(x)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Max(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_min_aten.py b/tests/py/dynamo/conversion/test_min_aten.py
index 3d0cd29923..17e034d3f4 100644
--- a/tests/py/dynamo/conversion/test_min_aten.py
+++ b/tests/py/dynamo/conversion/test_min_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -65,6 +66,62 @@ def forward(self, x):
             check_dtype=False,
         )
 
+    @parameterized.expand(
+        [
+            (1, True, (2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            (2, False, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            (-1, True, (2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_min_dim_dynamic_shape(
+        self, dim, keep_dim, min_shape, opt_shape, max_shape
+    ):
+        class Min(nn.Module):
+            def __init__(self, dim):
+                super().__init__()
+                self.dim = dim
+
+            def forward(self, x):
+                return torch.ops.aten.min.dim(x, dim, keep_dim)[0]
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Min(dim),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            ((2, 2, 3), (2, 3, 3), (3, 3, 4)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+            ((2, 3, 5), (3, 4, 6), (4, 5, 7)),
+        ]
+    )
+    def test_min_default_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class Min(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.min.default(x)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Min(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
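[Note, not part of the patch: the `[0]` indexing in the `dim` tests is there because `aten.max.dim` and `aten.min.dim` return a `(values, indices)` pair, which is also why those overloads are gated by `one_user_validator` in the converter registrations. A plain-PyTorch illustration:]

```python
import torch

x = torch.randn(2, 3, 5)
# The .dim overload returns both the reduced values and their indices.
values, indices = torch.ops.aten.max.dim(x, 1, True)
print(values.shape, indices.shape)  # torch.Size([2, 1, 5]) torch.Size([2, 1, 5])
# The .default overload reduces every axis down to a 0-d tensor.
print(torch.ops.aten.min.default(x).shape)  # torch.Size([])
```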
diff --git a/tests/py/dynamo/conversion/test_prod_aten.py b/tests/py/dynamo/conversion/test_prod_aten.py
index 3fbb602098..ee6b91e6f8 100644
--- a/tests/py/dynamo/conversion/test_prod_aten.py
+++ b/tests/py/dynamo/conversion/test_prod_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -68,6 +69,33 @@ def forward(self, x):
             use_dynamo_tracer=True,
         )
 
+    @parameterized.expand(
+        [
+            (0, (2, 3), (2, 4), (3, 5)),
+            (1, (2, 3), (2, 4), (3, 5)),
+            (2, (2, 2, 4), (2, 3, 4), (3, 4, 5)),
+            (-1, (2, 2, 4), (2, 3, 4), (3, 4, 5)),
+        ]
+    )
+    def test_prod_dynamic_shape(self, dim, min_shape, opt_shape, max_shape):
+        class Prod(nn.Module):
+            def forward(self, x):
+                return torch.prod(x, dim)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Prod(),
+            input_specs,
+            use_dynamo_tracer=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/tests/py/dynamo/conversion/test_sum_aten.py b/tests/py/dynamo/conversion/test_sum_aten.py
index bac8c7edf1..ad8ad6fec1 100644
--- a/tests/py/dynamo/conversion/test_sum_aten.py
+++ b/tests/py/dynamo/conversion/test_sum_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -130,6 +131,39 @@ def forward(self, x):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            ([0], (2, 3), (2, 4), (3, 5)),
+            ([1], (2, 3), (2, 4), (3, 5)),
+            (
+                [
+                    2,
+                ],
+                (2, 2, 4),
+                (2, 3, 4),
+                (3, 4, 5),
+            ),
+            ([0, 1], (2, 2, 4), (2, 3, 4), (3, 4, 5)),
+        ]
+    )
+    def test_sum_dynamic_shape(self, dim, min_shape, opt_shape, max_shape):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.ops.prims.sum.default(x, dim)
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Sum(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
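[Note, not part of the patch: `test_sum_dynamic_shape` drives the prims path directly, while `test_prod_dynamic_shape` relies on `use_dynamo_tracer=True` to lower `torch.prod(x, dim)` to `aten.prod.dim_int`, matching the newly flagged converter. A sketch of the relationship the sum test exercises: `prims.sum` takes an explicit list of dims and should agree with `aten.sum.dim_IntList` on these inputs. The comparison below is a sanity check written for this note, not part of the test suite:]

```python
import torch

x = torch.randn(2, 2, 4)
# prims.sum requires the dims as a list; aten.sum.dim_IntList accepts the same.
a = torch.ops.prims.sum.default(x, [0, 1])
b = torch.ops.aten.sum.dim_IntList(x, [0, 1])
print(torch.allclose(a, b))  # True
```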