1
+ import unittest
2
+
1
3
import torch
2
4
from parameterized import parameterized
3
5
from torch .testing ._internal .common_utils import run_tests
6
+ from torch_tensorrt import Input
4
7
5
8
from .harness import DispatchTestCase
6
9
@@ -27,6 +30,7 @@ def forward(self, input):
27
30
self .run_test (
28
31
TestChunk (),
29
32
input ,
33
+ use_dynamo_tracer = True ,
30
34
)
31
35
32
36
@parameterized .expand (
@@ -51,6 +55,7 @@ def forward(self, input):
51
55
self .run_test (
52
56
TestChunk (),
53
57
input ,
58
+ use_dynamo_tracer = True ,
54
59
)
55
60
56
61
@parameterized .expand (
@@ -75,6 +80,106 @@ def forward(self, input):
75
80
self .run_test (
76
81
TestChunk (),
77
82
input ,
83
+ use_dynamo_tracer = True ,
84
+ )
85
+
86
+
87
+ #######################Dynamic cases#######################
88
+ # The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
89
+ @unittest .skip (
90
+ "Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663"
91
+ )
92
+ class TestChunkDynamicConverter (DispatchTestCase ):
93
+ @parameterized .expand (
94
+ [
95
+ ((1 ,), (1 ,), (3 ,), 3 , 0 ),
96
+ ((3 ,), (3 ,), (4 ,), 3 , 0 ),
97
+ ((4 ,), (4 ,), (6 ,), 3 , 0 ),
98
+ ((6 ,), (6 ,), (9 ,), 3 , 0 ),
99
+ ((3 ,), (3 ,), (4 ,), 1 , - 1 ),
100
+ ((3 ,), (3 ,), (4 ,), 3 , - 1 ),
101
+ ((3 ,), (3 ,), (4 ,), 4 , - 1 ),
102
+ ]
103
+ )
104
+ def test_chunk_1D (self , min_shape , opt_shape , max_shape , chunks , dim ):
105
+ class TestChunk (torch .nn .Module ):
106
+ def forward (self , input ):
107
+ out = torch .ops .aten .chunk .default (input , chunks , dim )
108
+ return out
109
+
110
+ input_specs = [
111
+ Input (
112
+ min_shape = min_shape ,
113
+ opt_shape = opt_shape ,
114
+ max_shape = max_shape ,
115
+ ),
116
+ ]
117
+ self .run_test_with_dynamic_shape (
118
+ TestChunk (),
119
+ input_specs ,
120
+ use_dynamo_tracer = True ,
121
+ )
122
+
123
+ @parameterized .expand (
124
+ [
125
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 1 , 0 ),
126
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 3 , 0 ),
127
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 4 , 0 ),
128
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 2 , - 2 ),
129
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 6 , - 2 ),
130
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 3 , 1 ),
131
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 4 , 1 ),
132
+ ((3 , 4 ), (3 , 4 ), (4 , 4 ), 5 , - 1 ),
133
+ ]
134
+ )
135
+ def test_chunk_2D (self , min_shape , opt_shape , max_shape , chunks , dim ):
136
+ class TestChunk (torch .nn .Module ):
137
+ def forward (self , input ):
138
+ out = torch .ops .aten .chunk .default (input , chunks , dim )
139
+ return out
140
+
141
+ input_specs = [
142
+ Input (
143
+ min_shape = min_shape ,
144
+ opt_shape = opt_shape ,
145
+ max_shape = max_shape ,
146
+ ),
147
+ ]
148
+ self .run_test_with_dynamic_shape (
149
+ TestChunk (),
150
+ input_specs ,
151
+ use_dynamo_tracer = True ,
152
+ )
153
+
154
+ @parameterized .expand (
155
+ [
156
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 1 , 0 ),
157
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 3 , - 3 ),
158
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 3 , 1 ),
159
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 4 , 1 ),
160
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 6 , - 2 ),
161
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 1 , 2 ),
162
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 3 , - 1 ),
163
+ ((3 , 4 , 2 ), (3 , 4 , 2 ), (4 , 4 , 2 ), 4 , - 1 ),
164
+ ]
165
+ )
166
+ def test_chunk_3D (self , min_shape , opt_shape , max_shape , chunks , dim ):
167
+ class TestChunk (torch .nn .Module ):
168
+ def forward (self , input ):
169
+ out = torch .ops .aten .chunk .default (input , chunks , dim )
170
+ return out
171
+
172
+ input_specs = [
173
+ Input (
174
+ min_shape = min_shape ,
175
+ opt_shape = opt_shape ,
176
+ max_shape = max_shape ,
177
+ ),
178
+ ]
179
+ self .run_test_with_dynamic_shape (
180
+ TestChunk (),
181
+ input_specs ,
182
+ use_dynamo_tracer = True ,
78
183
)
79
184
80
185
0 commit comments