Skip to content

Commit 958f1d1

Browse files
committed
chunk_validator
1 parent 4dbeafd commit 958f1d1

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def aten_ops_softmax(
692692

693693
@dynamo_tensorrt_converter(
694694
torch.ops.aten.split.Tensor,
695-
capability_validator=has_static_shapes_in_args([1]),
695+
capability_validator= (has_static_shapes_in_args([0]) and has_static_shapes_in_args([1])),
696696
supports_dynamic_shapes=True,
697697
)
698698
@dynamo_tensorrt_converter(
@@ -903,7 +903,10 @@ def aten_ops_slice(
903903
)
904904

905905

906-
@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
906+
@dynamo_tensorrt_converter(
907+
torch.ops.aten.chunk.default,
908+
supports_dynamic_shapes=True,
909+
)
907910
@enforce_tensor_types(
908911
{
909912
0: (TRTTensor,),

tests/py/dynamo/conversion/test_chunk_aten.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from parameterized import parameterized
33
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
45

56
from .harness import DispatchTestCase
67

@@ -27,6 +28,7 @@ def forward(self, input):
2728
self.run_test(
2829
TestChunk(),
2930
input,
31+
use_dynamo_tracer=True,
3032
)
3133

3234
@parameterized.expand(
@@ -51,6 +53,7 @@ def forward(self, input):
5153
self.run_test(
5254
TestChunk(),
5355
input,
56+
use_dynamo_tracer=True,
5457
)
5558

5659
@parameterized.expand(
@@ -75,8 +78,102 @@ def forward(self, input):
7578
self.run_test(
7679
TestChunk(),
7780
input,
81+
use_dynamo_tracer=True,
7882
)
7983

84+
#######################Dynamic cases################
85+
86+
# @parameterized.expand(
87+
# [
88+
# ((1,), (1,), (3,), 3, 0),
89+
# ((3,), (3,), (4,), 3, 0),
90+
# ((4,), (4,), (6,), 3, 0),
91+
# ((6,), (6,), (9,), 3, 0),
92+
# ((3,), (3,), (4,), 1, -1),
93+
# ((3,), (3,), (4,), 3, -1),
94+
# ((3,), (3,), (4,), 4, -1),
95+
# ]
96+
# )
97+
# def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim):
98+
# class TestChunk(torch.nn.Module):
99+
# def forward(self, input):
100+
# out = torch.ops.aten.chunk.default(input, chunks, dim)
101+
# return out
102+
103+
# input_specs = [
104+
# Input(
105+
# min_shape=min_shape,
106+
# opt_shape=opt_shape,
107+
# max_shape=max_shape,
108+
# ),
109+
# ]
110+
# self.run_test_with_dynamic_shape(
111+
# TestChunk(),
112+
# input_specs,
113+
# use_dynamo_tracer = True,
114+
# )
115+
116+
# @parameterized.expand(
117+
# [
118+
# ((3, 4), (3, 4), (4, 4), 1, 0),
119+
# ((3, 4), (3, 4), (4, 4), 3, 0),
120+
# ((3, 4), (3, 4), (4, 4), 4, 0),
121+
# ((3, 4), (3, 4), (4, 4), 2, -2),
122+
# ((3, 4), (3, 4), (4, 4), 6, -2),
123+
# ((3, 4), (3, 4), (4, 4), 3, 1),
124+
# ((3, 4), (3, 4), (4, 4), 4, 1),
125+
# ((3, 4), (3, 4), (4, 4), 5, -1),
126+
# ]
127+
# )
128+
# def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim):
129+
# class TestChunk(torch.nn.Module):
130+
# def forward(self, input):
131+
# out = torch.ops.aten.chunk.default(input, chunks, dim)
132+
# return out
133+
134+
# input_specs = [
135+
# Input(
136+
# min_shape=min_shape,
137+
# opt_shape=opt_shape,
138+
# max_shape=max_shape,
139+
# ),
140+
# ]
141+
# self.run_test_with_dynamic_shape(
142+
# TestChunk(),
143+
# input_specs,
144+
# use_dynamo_tracer = True,
145+
# )
146+
147+
# @parameterized.expand(
148+
# [
149+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0),
150+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3),
151+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1),
152+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1),
153+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2),
154+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2),
155+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1),
156+
# ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1),
157+
# ]
158+
# )
159+
# def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim):
160+
# class TestChunk(torch.nn.Module):
161+
# def forward(self, input):
162+
# out = torch.ops.aten.chunk.default(input, chunks, dim)
163+
# return out
164+
165+
# input_specs = [
166+
# Input(
167+
# min_shape=min_shape,
168+
# opt_shape=opt_shape,
169+
# max_shape=max_shape,
170+
# ),
171+
# ]
172+
# self.run_test_with_dynamic_shape(
173+
# TestChunk(),
174+
# input_specs,
175+
# use_dynamo_tracer = True,
176+
# )
80177

81178
if __name__ == "__main__":
82179
run_tests()

0 commit comments

Comments
 (0)