From 8cf6492bed2bbb2609d2a2b651726f7cad1c846a Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 17 Aug 2023 16:56:28 -0700 Subject: [PATCH 1/7] feat: support amax dynamo converter --- .../dynamo/conversion/aten_ops_converters.py | 25 +++++ .../dynamo/conversion/converter_utils.py | 5 + .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/reduce.py | 39 ++++++++ tests/py/dynamo/converters/test_amax_aten.py | 93 +++++++++++++++++++ 5 files changed, 163 insertions(+) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/reduce.py create mode 100644 tests/py/dynamo/converters/test_amax_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 75a7782354..c0d1fb08d5 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -420,3 +420,28 @@ def aten_ops_clone( name, args[0], ) + + +@dynamo_tensorrt_converter(torch.ops.aten.amax.default) +def aten_ops_amax( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + input_val = args[0] + if (isinstance(input_val, TRTTensor)) and ( + input_val.dtype == trt.int8 or input_val.dtype == trt.int32 + ): + input_val = cast_trt_tensor(network, input_val, trt.float32, name) + + return impl.reduce.amax( + network, + target, + SourceIR.ATEN, + name, + input_val, + args[1], + args_bounds_check(args, 2, replacement=False), + ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index ed0f1bb843..bb4466a21e 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -8,11 +8,13 @@ from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, unified_dtype_converter, + get_axes_for_reduce_op, ) from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor from .._SourceIR import SourceIR from .converter_registry import ConverterRegistry +import functools _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -157,3 +159,6 @@ def broadcastable( if not (a_shape[i] == b_shape[i] or a_shape[i] == 1 or b_shape[i] == 1): return False return True + + +get_axes_for_reduce_op = functools.partial(get_axes_for_reduce_op, has_implicit_batch_dimension=False) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 8f7ab1badc..6bd315871c 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -9,6 +9,7 @@ matmul, normalization, permutation, + reduce, select, shape, slice, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py new file mode 100644 index 0000000000..54ad1d3047 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py @@ -0,0 +1,39 @@ +from typing import Optional, Union, cast, Any, Tuple + +import tensorrt as trt + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, +) +from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op + + +def amax( + network: TRTNetwork, + target: Target, + source_ir: 
Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Union[int, Tuple[int]], + keep_dims: Optional[bool] = False, + out: Optional[Any] = None +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError(f"amax received input {input} that is not part of the TensorRT region!" + ) + + if dim is None: + raise ValueError("amax requires specifying dimension(s) (dim).") + + layer = network.add_reduce( + input, + trt.ReduceOperation.MAX, + axes=get_axes_for_reduce_op(dim), + keep_dims=keep_dims + ) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/tests/py/dynamo/converters/test_amax_aten.py b/tests/py/dynamo/converters/test_amax_aten.py new file mode 100644 index 0000000000..5923e6d40e --- /dev/null +++ b/tests/py/dynamo/converters/test_amax_aten.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestAmaxConverter(DispatchTestCase): + @parameterized.expand( + [ + ((3, 2, 4), 1, True), + ((2, 3, 4, 5), 3, True), + ((2, 3, 4, 5), 2, False), + ((6, 7, 5, 4, 5), 4, False), + ] + ) + def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype): + class Amax(nn.Module): + def forward(self, x): + return torch.amax(x, dim=dim, keepdim=keep_dims) + + inputs = [torch.randn(*input_shape, dtype=dtype)] + self.run_test( + Amax(), + inputs, + expected_ops={torch.ops.aten.amax.default}, + ) + + @parameterized.expand( + [ + ((3, 2, 4), [1], True), + ((2, 1, 4, 5), [0, 3], True), + ((2, 3, 4, 5), [0, 1, 2, 3], False), + ((6, 7, 5, 4, 5), [1, 3, 4], False), + ] + ) + def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype): + class Amax(nn.Module): + def forward(self, x): + return torch.amax(x, dim=dim, keepdim=keep_dims) + + inputs = [torch.randn(*input_shape, dtype=dtype)] + self.run_test( + Amax(), + inputs, + expected_ops={torch.ops.aten.amax.default}, + ) + + @parameterized.expand( + [ + ((3, 2, 4), 1, True, torch.int, 0, 5), + ((2, 3, 4, 5), 3, True, torch.int, -10, 10), + ((2, 3, 4, 5), 2, False, torch.int32, -5, 0), + ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5), + ] + ) + def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high): + class Amax(nn.Module): + def forward(self, x): + return torch.amax(x, dim=dim, keepdim=keep_dims) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + Amax(), + inputs, + expected_ops={torch.ops.aten.amax.default}, + check_dtype=False, + ) + + @parameterized.expand( + [ + ((3, 2, 4), [1], True, torch.int, 0, 5), + ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10), + ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0), + ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5), + ] + ) + def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high): + class Amax(nn.Module): + def forward(self, x): + return torch.amax(x, dim=dim, keepdim=keep_dims) + + inputs = [torch.randint(low, high, input_shape, dtype=dtype)] + self.run_test( + Amax(), + inputs, + expected_ops={torch.ops.aten.amax.default}, + check_dtype=False, + ) + + +if __name__ == "__main__": + run_tests() From e73ceaea92e82fa6f4585d144a4028a8b71ee292 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 18 Aug 2023 14:56:03 -0700 Subject: [PATCH 2/7] lint code --- .../dynamo/conversion/converter_utils.py | 8 ++++--- .../dynamo/conversion/impl/reduce.py | 23 ++++++++----------- 
tests/py/dynamo/converters/test_amax_aten.py | 4 ++-- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index bb4466a21e..e33bf09903 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,3 +1,4 @@ +import functools import logging import re from typing import List, Optional @@ -7,14 +8,13 @@ from torch.fx.node import Target from torch_tensorrt.fx.converters.converter_utils import ( Frameworks, - unified_dtype_converter, get_axes_for_reduce_op, + unified_dtype_converter, ) from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor from .._SourceIR import SourceIR from .converter_registry import ConverterRegistry -import functools _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -161,4 +161,6 @@ def broadcastable( return True -get_axes_for_reduce_op = functools.partial(get_axes_for_reduce_op, has_implicit_batch_dimension=False) +get_axes_for_reduce_op = functools.partial( + get_axes_for_reduce_op, has_implicit_batch_dimension=False +) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py index 54ad1d3047..1ac173794d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py @@ -1,15 +1,11 @@ -from typing import Optional, Union, cast, Any, Tuple +from typing import Any, Optional, Tuple, Union import tensorrt as trt - from torch.fx.node import Target - -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.converters.converter_utils import ( - set_layer_name, -) from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor def amax( @@ -20,20 +16,21 @@ def amax( input: TRTTensor, dim: Union[int, Tuple[int]], keep_dims: Optional[bool] = False, - out: Optional[Any] = None + out: Optional[Any] = None, ) -> TRTTensor: if not isinstance(input, TRTTensor): - raise RuntimeError(f"amax received input {input} that is not part of the TensorRT region!" - ) + raise RuntimeError( + f"amax received input {input} that is not part of the TensorRT region!" 
+ ) if dim is None: raise ValueError("amax requires specifying dimension(s) (dim).") layer = network.add_reduce( - input, - trt.ReduceOperation.MAX, + input, + trt.ReduceOperation.MAX, axes=get_axes_for_reduce_op(dim), - keep_dims=keep_dims + keep_dims=keep_dims, ) set_layer_name(layer, target, name) return layer.get_output(0) diff --git a/tests/py/dynamo/converters/test_amax_aten.py b/tests/py/dynamo/converters/test_amax_aten.py index 5923e6d40e..1e8988cbcc 100644 --- a/tests/py/dynamo/converters/test_amax_aten.py +++ b/tests/py/dynamo/converters/test_amax_aten.py @@ -45,7 +45,7 @@ def forward(self, x): inputs, expected_ops={torch.ops.aten.amax.default}, ) - + @parameterized.expand( [ ((3, 2, 4), 1, True, torch.int, 0, 5), @@ -66,7 +66,7 @@ def forward(self, x): expected_ops={torch.ops.aten.amax.default}, check_dtype=False, ) - + @parameterized.expand( [ ((3, 2, 4), [1], True, torch.int, 0, 5), From 8c580d5651507112b717da8a36ab36bbdc85219b Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 18 Aug 2023 17:37:04 -0700 Subject: [PATCH 3/7] fix bugs in test --- tests/py/dynamo/converters/test_amax_aten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/py/dynamo/converters/test_amax_aten.py b/tests/py/dynamo/converters/test_amax_aten.py index 1e8988cbcc..c12e8aa07f 100644 --- a/tests/py/dynamo/converters/test_amax_aten.py +++ b/tests/py/dynamo/converters/test_amax_aten.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn +from harness import DispatchTestCase from parameterized import parameterized from torch.testing._internal.common_utils import run_tests -from torch_tensorrt.dynamo.test_utils import DispatchTestCase class TestAmaxConverter(DispatchTestCase): From cbe8ff264dc7eb393bb097dfe7b1f5b116dbd20e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 22 Aug 2023 16:22:55 -0700 Subject: [PATCH 4/7] minor fix --- .../dynamo/conversion/aten_ops_converters.py | 8 +------ .../dynamo/conversion/impl/reduce.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index c0d1fb08d5..ae6939a84a 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -430,18 +430,12 @@ def aten_ops_amax( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = args[0] - if (isinstance(input_val, TRTTensor)) and ( - input_val.dtype == trt.int8 or input_val.dtype == trt.int32 - ): - input_val = cast_trt_tensor(network, input_val, trt.float32, name) - return impl.reduce.amax( network, target, SourceIR.ATEN, name, - input_val, + args[0], args[1], args_bounds_check(args, 2, replacement=False), ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py index 1ac173794d..7162c54d32 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py @@ -1,9 +1,12 @@ -from typing import Any, Optional, Tuple, Union +from typing import Optional, Tuple, Union import tensorrt as trt from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, + get_axes_for_reduce_op, +) from torch_tensorrt.fx.converters.converter_utils import 
set_layer_name from torch_tensorrt.fx.types import TRTNetwork, TRTTensor @@ -13,24 +16,23 @@ def amax( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, + input_val: TRTTensor, dim: Union[int, Tuple[int]], - keep_dims: Optional[bool] = False, - out: Optional[Any] = None, + keepdim: bool = False, ) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"amax received input {input} that is not part of the TensorRT region!" - ) + if (isinstance(input_val, TRTTensor)) and ( + input_val.dtype == trt.int8 or input_val.dtype == trt.int32 + ): + input_val = cast_trt_tensor(network, input_val, trt.float32, name) if dim is None: raise ValueError("amax requires specifying dimension(s) (dim).") layer = network.add_reduce( - input, + input_val, trt.ReduceOperation.MAX, axes=get_axes_for_reduce_op(dim), - keep_dims=keep_dims, + keep_dims=keepdim, ) set_layer_name(layer, target, name) return layer.get_output(0) From 9a1c4265c1464eb58d239837c32c7023a8ea26bc Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 23 Aug 2023 13:13:50 -0700 Subject: [PATCH 5/7] minor fix --- py/torch_tensorrt/dynamo/conversion/impl/reduce.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py index 7162c54d32..53070761dd 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/reduce.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/reduce.py @@ -25,14 +25,11 @@ def amax( ): input_val = cast_trt_tensor(network, input_val, trt.float32, name) - if dim is None: - raise ValueError("amax requires specifying dimension(s) (dim).") - layer = network.add_reduce( input_val, trt.ReduceOperation.MAX, axes=get_axes_for_reduce_op(dim), keep_dims=keepdim, ) - set_layer_name(layer, target, name) + set_layer_name(layer, target, name, source_ir) return layer.get_output(0) From ecb2316a236b540233f6cb768bdac3353b9c9ec4 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 23 Aug 2023 13:36:53 -0700 Subject: [PATCH 6/7] resolve conflicts --- .../dynamo/conversion/aten_ops_converters.py | 392 ++++++++++++++++++ 1 file changed, 392 insertions(+) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index ae6939a84a..d7ae5686eb 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -422,6 +422,24 @@ def aten_ops_clone( ) +@dynamo_tensorrt_converter(torch.ops.aten.expand.default) +def aten_ops_expand( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice.expand( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + @dynamo_tensorrt_converter(torch.ops.aten.amax.default) def aten_ops_amax( network: TRTNetwork, @@ -439,3 +457,377 @@ def aten_ops_amax( args[1], args_bounds_check(args, 2, replacement=False), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.exp.default) # type: ignore[misc] +def aten_ops_exp( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.exp( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.log.default) # type: ignore[misc] +def aten_ops_log( + network: TRTNetwork, + target: Target, + 
args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.log( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.sqrt.default) # type: ignore[misc] +def aten_ops_sqrt( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.sqrt( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.reciprocal.default) # type: ignore[misc] +def aten_ops_recip( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.recip( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.abs.default) # type: ignore[misc] +def aten_ops_abs( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.abs( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.sin.default) # type: ignore[misc] +def aten_ops_sin( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.sin( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.cos.default) # type: ignore[misc] +def aten_ops_cos( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.cos( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.tan.default) # type: ignore[misc] +def aten_ops_tan( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.tan( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.sinh.default) # type: ignore[misc] +def aten_ops_sinh( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.sinh( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.cosh.default) # type: ignore[misc] +def aten_ops_cosh( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.cosh( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.asin.default) # type: ignore[misc] +def aten_ops_asin( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.asin( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.acos.default) # type: ignore[misc] +def aten_ops_acos( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: 
str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.acos( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.atan.default) # type: ignore[misc] +def aten_ops_atan( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.atan( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.asinh.default) # type: ignore[misc] +def aten_ops_asinh( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.asinh( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.acosh.default) # type: ignore[misc] +def aten_ops_acosh( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.acosh( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.atanh.default) # type: ignore[misc] +def aten_ops_atanh( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.atanh( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.ceil.default) # type: ignore[misc] +def aten_ops_ceil( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.ceil( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.floor.default) # type: ignore[misc] +def aten_ops_floor( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.floor( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.logical_not.default) # type: ignore[misc] +def aten_ops_logical_not( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.logical_not( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.sign.default) # type: ignore[misc] +def aten_ops_sign( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.sign( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.round.default) # type: ignore[misc] +def aten_ops_round( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.round( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.isinf.default) # type: ignore[misc] +def aten_ops_isinf( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> 
Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.isinf( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) From 53931d4ef4b24fbb17f57092ae8361e817c83f7e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 24 Aug 2023 16:54:38 -0700 Subject: [PATCH 7/7] fix test bugs and add capability_validator --- .../dynamo/conversion/aten_ops_converters.py | 14 +++++++++++++- tests/py/dynamo/converters/test_amax_aten.py | 8 ++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 44bdb4118c..451d218ee7 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -440,7 +440,19 @@ def aten_ops_expand( ) -@dynamo_tensorrt_converter(torch.ops.aten.amax.default) +def amax_param_validator(amax_node: Node) -> bool: + if len(amax_node.args) < 2: + _LOGGER.debug( + f"At least two args input and dim should be provided, but only got {len(amax_node.args)} args." + ) + return False + + return True + + +@dynamo_tensorrt_converter( + torch.ops.aten.amax.default, capability_validator=amax_param_validator +) def aten_ops_amax( network: TRTNetwork, target: Target, diff --git a/tests/py/dynamo/converters/test_amax_aten.py b/tests/py/dynamo/converters/test_amax_aten.py index c12e8aa07f..b6024c83ba 100644 --- a/tests/py/dynamo/converters/test_amax_aten.py +++ b/tests/py/dynamo/converters/test_amax_aten.py @@ -14,12 +14,12 @@ class TestAmaxConverter(DispatchTestCase): ((6, 7, 5, 4, 5), 4, False), ] ) - def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype): + def test_amax_dim_int_default(self, input_shape, dim, keep_dims): class Amax(nn.Module): def forward(self, x): return torch.amax(x, dim=dim, keepdim=keep_dims) - inputs = [torch.randn(*input_shape, dtype=dtype)] + inputs = [torch.randn(*input_shape)] self.run_test( Amax(), inputs, @@ -34,12 +34,12 @@ def forward(self, x): ((6, 7, 5, 4, 5), [1, 3, 4], False), ] ) - def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype): + def test_amax_dim_tuple_default(self, input_shape, dim, keep_dims): class Amax(nn.Module): def forward(self, x): return torch.amax(x, dim=dim, keepdim=keep_dims) - inputs = [torch.randn(*input_shape, dtype=dtype)] + inputs = [torch.randn(*input_shape)] self.run_test( Amax(), inputs,
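
For context, below is a minimal usage sketch of the converter added by this series. It assumes the public torch_tensorrt.compile entry point with ir="dynamo"; that compile call and its argument names are an assumption for illustration and are not part of these patches. With the converter registered, torch.amax in a module lowers to torch.ops.aten.amax.default, which is mapped to a single TensorRT reduce layer with trt.ReduceOperation.MAX as implemented in impl/reduce.py above (int8/int32 inputs are first cast to float32).

import torch
import torch.nn as nn
import torch_tensorrt  # assumes a build that includes this converter


class AmaxModule(nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.amax.default; the new converter maps this to
        # network.add_reduce(..., trt.ReduceOperation.MAX, ...) via impl.reduce.amax.
        return torch.amax(x, dim=1, keepdim=True)


model = AmaxModule().eval().cuda()
inputs = [torch.randn(2, 3, 4, 5).cuda()]
# NOTE: ir="dynamo" and this compile() signature are assumptions for illustration;
# check the torch_tensorrt release that ships this converter for the exact API.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # expected: torch.Size([2, 1, 4, 5])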