
Commit 5d58608

keehyunaperi044 authored and committed
feat: Add dynamic shape support for sub (#2888)
1 parent b22276f commit 5d58608

File tree

2 files changed: +98 -3 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 3 deletions
@@ -1732,8 +1732,8 @@ def aten_ops_minimum(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
 def aten_ops_sub(
     ctx: ConversionContext,
     target: Target,
@@ -1749,7 +1749,7 @@ def aten_ops_sub(
             ctx,
             target,
             SourceIR.ATEN,
-            name,
+            name + "_alpha",
             other,
             alpha,
         )
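
Marking both sub converters with supports_dynamic_shapes=True lets the dynamo frontend use them when the graph is compiled against a range of input shapes rather than a single static shape. A minimal sketch of what this enables, assuming the public torch_tensorrt.Input and torch_tensorrt.compile APIs on a CUDA machine; the module, shape ranges, and alpha value below are illustrative, not taken from this commit:

import torch
import torch_tensorrt


class SubModule(torch.nn.Module):
    def forward(self, lhs, rhs):
        # Traces to torch.ops.aten.sub.Tensor with an alpha multiplier on rhs
        return torch.sub(lhs, rhs, alpha=1.5)


# One dynamic-shape spec per input: the engine accepts any shape between
# min_shape and max_shape; opt_shape guides TensorRT's kernel selection.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 1, 1), opt_shape=(3, 2, 2), max_shape=(3, 2, 4), dtype=torch.float32
    ),
    torch_tensorrt.Input(
        min_shape=(1, 1, 1), opt_shape=(3, 2, 2), max_shape=(3, 2, 4), dtype=torch.float32
    ),
]

trt_mod = torch_tensorrt.compile(SubModule().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(torch.randn(3, 2, 3).cuda(), torch.randn(3, 2, 3).cuda()).shape)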

tests/py/dynamo/conversion/test_sub_aten.py

Lines changed: 95 additions & 0 deletions
@@ -76,6 +76,101 @@ def forward(self, lhs_val):
             inputs,
         )

+    @parameterized.expand(
+        [
+            (
+                "3d_2d_alpha_float32",
+                torch.float32,
+                (1, 1, 1),
+                (3, 2, 2),
+                (3, 2, 4),
+                (1, 1),
+                (2, 2),
+                (2, 4),
+                1.5,
+            ),
+            (
+                "2d_2d_alpha_int32",
+                torch.int32,
+                (3, 2),
+                (3, 2),
+                (3, 3),
+                (3, 2),
+                (3, 2),
+                (3, 3),
+                2,
+            ),
+        ]
+    )
+    def test_dynamic_shape_sub(self, *args):
+        class sub(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.sub.Tensor(lhs_val, rhs_val, alpha=args[8])
+
+        input_specs = [
+            Input(
+                min_shape=args[2],
+                opt_shape=args[3],
+                max_shape=args[4],
+                dtype=args[1],
+            ),
+            Input(
+                min_shape=args[5],
+                opt_shape=args[6],
+                max_shape=args[7],
+                dtype=args[1],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(sub(), input_specs)
+
+    @parameterized.expand(
+        [
+            (
+                "3d_scalar_float32",
+                torch.float32,
+                (1, 1, 1),
+                (3, 2, 2),
+                (3, 2, 4),
+                0.3,
+            )
+        ]
+    )
+    def test_dynamic_shape_sub_scalar(self, *args):
+        class sub(nn.Module):
+            def forward(self, lhs_val):
+                return torch.ops.aten.sub.Tensor(lhs_val, args[5])
+
+        input_specs = [
+            Input(
+                min_shape=args[2],
+                opt_shape=args[3],
+                max_shape=args[4],
+                dtype=args[1],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(sub(), input_specs)
+
+    @parameterized.expand(
+        [("scalar_2d_alpha_float32", torch.float32, (1, 1), (2, 2), (3, 4), 0.3, 1.5)]
+    )
+    def test_dynamic_shape_sub_scalar_alpha(self, *args):
+        class sub(nn.Module):
+            def forward(self, rhs_val):
+                return torch.ops.aten.sub.Tensor(args[5], rhs_val, alpha=args[6])
+
+        input_specs = [
+            Input(
+                min_shape=args[2],
+                opt_shape=args[3],
+                max_shape=args[4],
+                dtype=args[1],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(sub(), input_specs)
+

 if __name__ == "__main__":
     run_tests()
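
The positional tuples fed through parameterized.expand above are read back as args[n] inside each test body. For reference, the first test_dynamic_shape_sub case unpacks as follows; this is a sketch with named variables, the values are copied from that case, and torch_tensorrt.Input is assumed to be the Input class the test file imports:

import torch
from torch_tensorrt import Input

# ("3d_2d_alpha_float32", dtype, lhs_min, lhs_opt, lhs_max, rhs_min, rhs_opt, rhs_max, alpha)
dtype = torch.float32            # args[1]
lhs_spec = Input(
    min_shape=(1, 1, 1),         # args[2]
    opt_shape=(3, 2, 2),         # args[3]
    max_shape=(3, 2, 4),         # args[4]
    dtype=dtype,
)
rhs_spec = Input(
    min_shape=(1, 1),            # args[5]
    opt_shape=(2, 2),            # args[6]
    max_shape=(2, 4),            # args[7]
    dtype=dtype,
)
alpha = 1.5                      # args[8]; the op computes lhs - alpha * rhs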
