
Commit e3608a5

feat: support aten.amin dynamo converter (#2504)
1 parent 2f9b259 commit e3608a5

File tree

3 files changed: +153 −14 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 29 additions & 12 deletions
@@ -955,18 +955,11 @@ def aten_ops_expand(
 )


-def amax_param_validator(amax_node: Node) -> bool:
-    if len(amax_node.args) < 2:
-        _LOGGER.debug(
-            f"At least two args input and dim should be provided, but only got {len(amax_node.args)} args."
-        )
-        return False
-
-    return True
-
-
-@dynamo_tensorrt_converter(
-    torch.ops.aten.amax.default, capability_validator=amax_param_validator
+@dynamo_tensorrt_converter(torch.ops.aten.amax.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
 )
 def aten_ops_amax(
     ctx: ConversionContext,
@@ -986,6 +979,30 @@ def aten_ops_amax(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.amin.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_amin(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.reduce.amin(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, replacement=[]),
+        args_bounds_check(args, 2, replacement=False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.sum.default)
 @dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
 @dynamo_tensorrt_converter(torch.ops.prims.sum.default)
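With this registration in place, Dynamo graphs containing aten.amin.default can be lowered to TensorRT instead of falling back to PyTorch. A minimal sketch of exercising the converter end to end (the module, shapes, and variable names here are illustrative, not from this commit; assumes a CUDA-enabled Torch-TensorRT install):

import torch
import torch_tensorrt

class MinReduce(torch.nn.Module):
    def forward(self, x):
        # Traced by Dynamo as torch.ops.aten.amin.default
        return torch.amin(x, dim=1, keepdim=True)

model = MinReduce().eval().cuda()
inputs = [torch.randn(2, 3, 4).cuda()]
# Compile through the dynamo path; aten.amin now maps to a TensorRT reduce layer
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([2, 1, 4])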

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 29 additions & 2 deletions
@@ -19,15 +19,15 @@ def amax(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    dim: Union[int, Sequence[int]],
+    dim: Sequence[int] = [],
     keepdim: bool = False,
 ) -> TRTTensor:
     if (isinstance(input_val, TRTTensor)) and (
         input_val.dtype == trt.int8 or input_val.dtype == trt.int32
     ):
         input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

-    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
+    if isinstance(dim, (tuple, list)) and len(dim) == 0:
         dim = tuple(range(len(input_val.shape)))

     layer = ctx.net.add_reduce(
@@ -40,6 +40,33 @@ def amax(
     return layer.get_output(0)


+def amin(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: Sequence[int] = [],
+    keepdim: bool = False,
+) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
+
+    if isinstance(dim, (tuple, list)) and len(dim) == 0:
+        dim = tuple(range(len(input_val.shape)))
+
+    layer = ctx.net.add_reduce(
+        input_val,
+        trt.ReduceOperation.MIN,
+        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
+        keep_dims=keepdim,
+    )
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
 def sum(
     ctx: ConversionContext,
     target: Target,
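TensorRT's IReduceLayer takes its reduction dimensions as a bitmask rather than a list, which is what get_axes_for_reduce_op produces after get_positive_dim normalizes negative indices. A rough standalone illustration of that encoding (hypothetical helper, not the repo's actual implementation):

# One bit per reduced dimension, e.g. dims (0, 2) on a rank-4 tensor -> 0b0101
def axes_bitmask(dims, rank):
    mask = 0
    for d in dims:
        mask |= 1 << (d % rank)  # d % rank maps -1 to rank - 1, etc.
    return mask

assert axes_bitmask((0, 2), 4) == 0b0101
assert axes_bitmask((-1,), 3) == 0b100

Because dim defaults to [] and an empty list is expanded to tuple(range(rank)), calling amin with no dim reduces over every axis, matching torch.amin's full reduction.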
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestAminConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True),
+            ((2, 3, 4, 5), 3, True),
+            ((2, 3, 4, 5), 2, False),
+            ((6, 7, 5, 4, 5), 4, False),
+            ((1, 5, 2, 1), -1, True),
+        ]
+    )
+    def test_amin_dim_int_default(self, input_shape, dim, keep_dims):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amin(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((1, 2, 4), [], True),
+            ((3, 2, 4), [1], True),
+            ((2, 1, 4, 5), [0, 3], True),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False),
+        ]
+    )
+    def test_amin_dim_tuple_default(self, input_shape, dim, keep_dims):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Amin(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True, torch.int, 0, 5),
+            ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
+            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
+            ((1, 5, 2, 1), -4, False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amin_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amin(),
+            inputs,
+            check_dtype=False,
+        )
+
+    @parameterized.expand(
+        [
+            ((1, 2, 4), [], True, torch.int, 0, 5),
+            ((3, 2, 4), [1], True, torch.int, 0, 5),
+            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
+            ((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5),
+        ]
+    )
+    def test_amin_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Amin(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.amin.default(x, dim, keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Amin(),
+            inputs,
+            check_dtype=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
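The DispatchTestCase harness compares the TensorRT output against eager PyTorch, so the semantics the converter must match can be checked directly in eager mode (a standalone sanity snippet, not part of the test file):

import torch

x = torch.randn(3, 2, 4)
# Single dim, keepdim=True: shape (3, 1, 4)
assert torch.ops.aten.amin.default(x, [1], True).shape == (3, 1, 4)
# An empty dim list reduces over all axes, mirroring the converter's dim == [] branch
assert torch.ops.aten.amin.default(x, [], False).item() == x.min().item()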
