Skip to content

Commit 050272f

Browse files
add dynamic shape support for aten.ops.gt and aten.ops.ge (#2883)
1 parent 0d93425 commit 050272f

File tree

3 files changed

+106
-4
lines changed

3 files changed

+106
-4
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,8 +2161,8 @@ def aten_ops_ne(
21612161
)
21622162

21632163

2164-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
2165-
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
2164+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor, supports_dynamic_shapes=True)
2165+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar, supports_dynamic_shapes=True)
21662166
@enforce_tensor_types(
21672167
{
21682168
0: (TRTTensor,),
@@ -2185,8 +2185,8 @@ def aten_ops_gt(
21852185
)
21862186

21872187

2188-
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor)
2189-
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar)
2188+
@dynamo_tensorrt_converter(torch.ops.aten.ge.Tensor, supports_dynamic_shapes=True)
2189+
@dynamo_tensorrt_converter(torch.ops.aten.ge.Scalar, supports_dynamic_shapes=True)
21902190
@enforce_tensor_types(
21912191
{
21922192
0: (TRTTensor,),

tests/py/dynamo/conversion/test_ge_aten.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -61,6 +62,56 @@ def forward(self, lhs_val):
6162
inputs,
6263
)
6364

65+
@parameterized.expand(
66+
[
67+
("2d_2d", (5, 3), (5, 1)),
68+
("3d_2d", (5, 3, 2), (3, 1)),
69+
("4d_3d", (5, 3, 4, 1), (3, 1, 1)),
70+
]
71+
)
72+
def test_ge_tensor_broadcast(self, _, lshape, rshape):
73+
class ge(nn.Module):
74+
def forward(self, lhs_val, rhs_val):
75+
return torch.ops.aten.ge.Tensor(lhs_val, rhs_val)
76+
77+
inputs = [
78+
torch.randint(0, 3, lshape, dtype=torch.int32),
79+
torch.randint(0, 3, rshape, dtype=torch.int32),
80+
]
81+
self.run_test(
82+
ge(),
83+
inputs,
84+
)
85+
86+
@parameterized.expand(
87+
[
88+
("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)),
89+
("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)),
90+
]
91+
)
92+
def test_ge_dynamic_tensor(self, *args):
93+
class ge(nn.Module):
94+
def forward(self, lhs_val, rhs_val):
95+
return torch.ops.aten.ge.Tensor(lhs_val, rhs_val)
96+
97+
input_specs = [
98+
Input(
99+
min_shape=args[1],
100+
opt_shape=args[2],
101+
max_shape=args[3],
102+
),
103+
Input(
104+
min_shape=args[4],
105+
opt_shape=args[5],
106+
max_shape=args[6],
107+
),
108+
]
109+
110+
self.run_test_with_dynamic_shape(
111+
ge(),
112+
input_specs,
113+
)
114+
64115

65116
if __name__ == "__main__":
66117
run_tests()

tests/py/dynamo/conversion/test_gt_aten.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -58,6 +59,56 @@ def forward(self, lhs_val):
5859
inputs,
5960
)
6061

62+
@parameterized.expand(
63+
[
64+
("2d_2d", (5, 3), (5, 1)),
65+
("3d_2d", (5, 3, 2), (3, 1)),
66+
("4d_3d", (5, 3, 4, 1), (3, 1, 1)),
67+
]
68+
)
69+
def test_gt_tensor_broadcast(self, _, lshape, rshape):
70+
class gt(nn.Module):
71+
def forward(self, lhs_val, rhs_val):
72+
return torch.ops.aten.gt.Tensor(lhs_val, rhs_val)
73+
74+
inputs = [
75+
torch.randint(0, 3, lshape, dtype=torch.int32),
76+
torch.randint(0, 3, rshape, dtype=torch.int32),
77+
]
78+
self.run_test(
79+
gt(),
80+
inputs,
81+
)
82+
83+
@parameterized.expand(
84+
[
85+
("2d_2d", (2, 3), (4, 3), (5, 3), (2, 3), (4, 3), (5, 3)),
86+
("3d_2d", (2, 2, 2), (2, 3, 2), (2, 4, 2), (2, 1), (3, 1), (4, 1)),
87+
]
88+
)
89+
def test_gt_dynamic_tensor(self, *args):
90+
class gt(nn.Module):
91+
def forward(self, lhs_val, rhs_val):
92+
return torch.ops.aten.gt.Tensor(lhs_val, rhs_val)
93+
94+
input_specs = [
95+
Input(
96+
min_shape=args[1],
97+
opt_shape=args[2],
98+
max_shape=args[3],
99+
),
100+
Input(
101+
min_shape=args[4],
102+
opt_shape=args[5],
103+
max_shape=args[6],
104+
),
105+
]
106+
107+
self.run_test_with_dynamic_shape(
108+
gt(),
109+
input_specs,
110+
)
111+
61112

62113
if __name__ == "__main__":
63114
run_tests()

0 commit comments

Comments
 (0)