1
1
import torch
2
2
from parameterized import parameterized
3
3
from torch .testing ._internal .common_utils import run_tests
4
+ from torch_tensorrt import Input
4
5
5
6
from .harness import DispatchTestCase
6
7
@@ -27,6 +28,7 @@ def forward(self, input):
27
28
self .run_test (
28
29
TestChunk (),
29
30
input ,
31
+ use_dynamo_tracer = True ,
30
32
)
31
33
32
34
@parameterized .expand (
@@ -51,6 +53,7 @@ def forward(self, input):
51
53
self .run_test (
52
54
TestChunk (),
53
55
input ,
56
+ use_dynamo_tracer = True ,
54
57
)
55
58
56
59
@parameterized .expand (
@@ -75,8 +78,102 @@ def forward(self, input):
75
78
self .run_test (
76
79
TestChunk (),
77
80
input ,
81
+ use_dynamo_tracer = True ,
78
82
)
79
83
84
+ #######################Dynamic cases################
85
+
86
+ # @parameterized.expand(
87
+ # [
88
+ # ((1,), (1,), (3,), 3, 0),
89
+ # ((3,), (3,), (4,), 3, 0),
90
+ # ((4,), (4,), (6,), 3, 0),
91
+ # ((6,), (6,), (9,), 3, 0),
92
+ # ((3,), (3,), (4,), 1, -1),
93
+ # ((3,), (3,), (4,), 3, -1),
94
+ # ((3,), (3,), (4,), 4, -1),
95
+ # ]
96
+ # )
97
+ # def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim):
98
+ # class TestChunk(torch.nn.Module):
99
+ # def forward(self, input):
100
+ # out = torch.ops.aten.chunk.default(input, chunks, dim)
101
+ # return out
102
+
103
+ # input_specs = [
104
+ # Input(
105
+ # min_shape=min_shape,
106
+ # opt_shape=opt_shape,
107
+ # max_shape=max_shape,
108
+ # ),
109
+ # ]
110
+ # self.run_test_with_dynamic_shape(
111
+ # TestChunk(),
112
+ # input_specs,
113
+ # use_dynamo_tracer = True,
114
+ # )
115
+
116
+ # @parameterized.expand(
117
+ # [
118
+ # ((3, 4), (3, 4), (4, 4), 1, 0),
119
+ # ((3, 4), (3, 4), (4, 4), 3, 0),
120
+ # ((3, 4), (3, 4), (4, 4), 4, 0),
121
+ # ((3, 4), (3, 4), (4, 4), 2, -2),
122
+ # ((3, 4), (3, 4), (4, 4), 6, -2),
123
+ # ((3, 4), (3, 4), (4, 4), 3, 1),
124
+ # ((3, 4), (3, 4), (4, 4), 4, 1),
125
+ # ((3, 4), (3, 4), (4, 4), 5, -1),
126
+ # ]
127
+ # )
128
+ # def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim):
129
+ # class TestChunk(torch.nn.Module):
130
+ # def forward(self, input):
131
+ # out = torch.ops.aten.chunk.default(input, chunks, dim)
132
+ # return out
133
+
134
+ # input_specs = [
135
+ # Input(
136
+ # min_shape=min_shape,
137
+ # opt_shape=opt_shape,
138
+ # max_shape=max_shape,
139
+ # ),
140
+ # ]
141
+ # self.run_test_with_dynamic_shape(
142
+ # TestChunk(),
143
+ # input_specs,
144
+ # use_dynamo_tracer = True,
145
+ # )
146
+
147
+ # @parameterized.expand(
148
+ # [
149
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0),
150
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3),
151
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1),
152
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1),
153
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2),
154
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2),
155
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1),
156
+ # ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1),
157
+ # ]
158
+ # )
159
+ # def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim):
160
+ # class TestChunk(torch.nn.Module):
161
+ # def forward(self, input):
162
+ # out = torch.ops.aten.chunk.default(input, chunks, dim)
163
+ # return out
164
+
165
+ # input_specs = [
166
+ # Input(
167
+ # min_shape=min_shape,
168
+ # opt_shape=opt_shape,
169
+ # max_shape=max_shape,
170
+ # ),
171
+ # ]
172
+ # self.run_test_with_dynamic_shape(
173
+ # TestChunk(),
174
+ # input_specs,
175
+ # use_dynamo_tracer = True,
176
+ # )
80
177
81
178
if __name__ == "__main__" :
82
179
run_tests ()
0 commit comments