Skip to content

Commit 5bd948f

Browse files
authored
feat: dynamic shape support for adaptive_avg_poolNd (partially) (#3021)
1 parent 8536289 commit 5bd948f

File tree

3 files changed

+112
-35
lines changed

3 files changed

+112
-35
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2603,7 +2603,9 @@ def aten_ops_avg_pool(
26032603
)
26042604

26052605

2606-
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool1d.default)
2606+
@dynamo_tensorrt_converter(
2607+
torch.ops.aten.adaptive_avg_pool1d.default, supports_dynamic_shapes=True
2608+
)
26072609
@enforce_tensor_types(
26082610
{
26092611
0: (TRTTensor,),
@@ -2626,10 +2628,18 @@ def aten_ops_adaptive_avg_pool1d(
26262628
)
26272629

26282630

2629-
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool2d.default)
2630-
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
2631-
@dynamo_tensorrt_converter(torch.ops.aten.adaptive_avg_pool3d.default)
2632-
@dynamo_tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
2631+
@dynamo_tensorrt_converter(
2632+
torch.ops.aten.adaptive_avg_pool2d.default, supports_dynamic_shapes=True
2633+
)
2634+
@dynamo_tensorrt_converter(
2635+
torch.ops.aten._adaptive_avg_pool2d.default, supports_dynamic_shapes=True
2636+
)
2637+
@dynamo_tensorrt_converter(
2638+
torch.ops.aten.adaptive_avg_pool3d.default, supports_dynamic_shapes=True
2639+
)
2640+
@dynamo_tensorrt_converter(
2641+
torch.ops.aten._adaptive_avg_pool3d.default, supports_dynamic_shapes=True
2642+
)
26332643
@enforce_tensor_types(
26342644
{
26352645
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/pool.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
124124
"""Calculate the end index of each pooling window"""
125125
return math.ceil((float(idx + 1) * float(in_dim)) / out_dim)
126126

127+
if has_dynamic_shape(input.shape):
128+
assert (
129+
input.shape[-1] != -1 and input.shape[-2] != -1
130+
), "Last 2 dimensions can't be dynamic for adaptive_avg_pool1d."
131+
127132
in_dim = input.shape[-1]
128133
out_dim = output_size if isinstance(output_size, int) else output_size[0]
129134
output_list = []
@@ -179,6 +184,18 @@ def adaptive_avg_poolNd(
179184
input: TRTTensor,
180185
output_size: Sequence[int],
181186
) -> TRTTensor:
187+
if has_dynamic_shape(input.shape):
188+
if len(output_size) == 2: # adaptive_avg_pool2d
189+
assert (
190+
input.shape[-1] != -1 and input.shape[-2] != -1
191+
), "Last 2 dimensions can't be dynamic for adaptive_avg_pool2d."
192+
elif len(output_size) == 3: # adaptive_avg_pool3d
193+
assert (
194+
input.shape[-1] != -1
195+
and input.shape[-2] != -1
196+
and input.shape[-3] != -1
197+
), "Last 3 dimensions can't be dynamic for adaptive_avg_pool3d."
198+
182199
input_shape = input.shape
183200
input_rank = len(input_shape)
184201
output_rank = len(output_size)

tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,40 @@ def forward(self, x):
7979
enable_passes=True,
8080
)
8181

82+
@parameterized.expand(
83+
[
84+
(
85+
(1, 3, 3),
86+
(2, 3, 3),
87+
(3, 3, 3),
88+
torch.float,
89+
(2,),
90+
),
91+
]
92+
)
93+
def test_dynamic_shape_adaptive_pool1d(
94+
self,
95+
min_shape,
96+
opt_shape,
97+
max_shape,
98+
type,
99+
output_size,
100+
):
101+
class adaptive_pool1d(torch.nn.Module):
102+
def forward(self, x):
103+
return torch.ops.aten.adaptive_avg_pool1d.default(x, output_size)
104+
105+
input_specs = [
106+
Input(
107+
min_shape=min_shape,
108+
opt_shape=opt_shape,
109+
max_shape=max_shape,
110+
dtype=type,
111+
),
112+
]
113+
114+
self.run_test_with_dynamic_shape(adaptive_pool1d(), input_specs)
115+
82116
@parameterized.expand(
83117
[
84118
# 3d input
@@ -159,29 +193,37 @@ def forward(self, x):
159193

160194
@parameterized.expand(
161195
[
162-
((1, 2),),
196+
(
197+
(1, 1, 3, 3),
198+
(2, 2, 3, 3),
199+
(3, 3, 3, 3),
200+
torch.float,
201+
(2, 2),
202+
),
163203
]
164204
)
165-
def test_adaptive_avg_pool2d_dynamic(self, output_size):
166-
class TestModule(torch.nn.Module):
167-
def __init__(self):
168-
super().__init__()
169-
205+
def test_dynamic_shape_adaptive_pool2d(
206+
self,
207+
min_shape,
208+
opt_shape,
209+
max_shape,
210+
type,
211+
output_size,
212+
):
213+
class adaptive_pool2d(torch.nn.Module):
170214
def forward(self, x):
171-
out = torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)
172-
return out
215+
return torch.ops.aten.adaptive_avg_pool2d.default(x, output_size)
173216

174217
input_specs = [
175218
Input(
176-
shape=(-1, 2, 3, 2),
177-
dtype=torch.float32,
178-
shape_ranges=[((1, 2, 3, 2), (3, 2, 3, 2), (10, 2, 3, 2))],
219+
min_shape=min_shape,
220+
opt_shape=opt_shape,
221+
max_shape=max_shape,
222+
dtype=type,
179223
),
180224
]
181-
self.run_test_with_dynamic_shape(
182-
TestModule(),
183-
input_specs,
184-
)
225+
226+
self.run_test_with_dynamic_shape(adaptive_pool2d(), input_specs)
185227

186228
@parameterized.expand(
187229
[
@@ -271,29 +313,37 @@ def forward(self, x):
271313

272314
@parameterized.expand(
273315
[
274-
((1, 2, 3),),
316+
(
317+
(1, 1, 3, 3, 3),
318+
(2, 2, 3, 3, 3),
319+
(3, 3, 3, 3, 3),
320+
torch.float,
321+
(2, 2, 2),
322+
),
275323
]
276324
)
277-
def test_adaptive_avg_pool3d_dynamic(self, output_size):
278-
class TestModule(torch.nn.Module):
279-
def __init__(self):
280-
super().__init__()
281-
325+
def test_dynamic_shape_adaptive_pool3d(
326+
self,
327+
min_shape,
328+
opt_shape,
329+
max_shape,
330+
type,
331+
output_size,
332+
):
333+
class adaptive_pool3d(torch.nn.Module):
282334
def forward(self, x):
283-
out = torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)
284-
return out
335+
return torch.ops.aten.adaptive_avg_pool3d.default(x, output_size)
285336

286337
input_specs = [
287338
Input(
288-
shape=(-1, 2, 3, 1, 4),
289-
dtype=torch.float32,
290-
shape_ranges=[((1, 2, 3, 1, 4), (3, 2, 3, 1, 4), (10, 2, 3, 1, 4))],
339+
min_shape=min_shape,
340+
opt_shape=opt_shape,
341+
max_shape=max_shape,
342+
dtype=type,
291343
),
292344
]
293-
self.run_test_with_dynamic_shape(
294-
TestModule(),
295-
input_specs,
296-
)
345+
346+
self.run_test_with_dynamic_shape(adaptive_pool3d(), input_specs)
297347

298348

299349
if __name__ == "__main__":

0 commit comments

Comments
 (0)