22
33import torch
44import torch .nn as nn
5- from harness import DispatchTestCase
65from torch .testing ._internal .common_utils import run_tests
76from torch_tensorrt import Input
87
8+ from .harness import DispatchTestCase
9+
910
1011class TestIndexConverter (DispatchTestCase ):
11- def test_index_zero (self ):
12+ def test_index_zero_two_dim (self ):
1213 class TestModule (nn .Module ):
14+ def __init__ (self ):
15+ self .index0 = torch .randint (0 , 1 , (1 , 1 ))
16+ super ().__init__ ()
17+
1318 def forward (self , x ):
1419 index0 = torch .randint (0 , 1 , (1 , 1 ))
15- indices = [None , index0 ]
20+ indices = [None , self . index0 ]
1621 out = torch .ops .aten .index .Tensor (x , indices )
1722 return out
1823
@@ -23,11 +28,14 @@ def forward(self, x):
2328 expected_ops = {torch .ops .aten .index .Tensor },
2429 )
2530
26- def test_index_zero_index_one (self ):
31+ def test_index_zero_index_three_dim (self ):
2732 class TestModule (nn .Module ):
33+ def __init__ (self ):
34+ self .index0 = torch .randint (0 , 1 , (1 , 1 ))
35+ super ().__init__ ()
36+
2837 def forward (self , x ):
29- index0 = torch .randint (0 , 1 , (1 , 1 ))
30- indices = [None , index0 , None ]
38+ indices = [None , self .index0 , None ]
3139 out = torch .ops .aten .index .Tensor (x , indices )
3240 return out
3341
@@ -38,76 +46,101 @@ def forward(self, x):
3846 expected_ops = {torch .ops .aten .index .Tensor },
3947 )
4048
41- def test_index_zero_index_one_index_two (self ):
49+ def test_index_zero_index_one_index_two_three_dim (self ):
4250 class TestModule (nn .Module ):
51+ def __init__ (self ):
52+ self .index0 = torch .randint (0 , 1 , (1 , 1 ))
53+ self .index1 = torch .randint (0 , 1 , (1 , 1 ))
54+ super ().__init__ ()
55+
4356 def forward (self , x ):
44- index0 = torch .randint (0 , 1 , (1 , 1 ))
45- index1 = torch .randint (0 , 1 , (1 , 1 ))
46- indices = [None , index0 , index1 ]
57+ indices = [None , self .index0 , self .index1 ]
4758 out = torch .ops .aten .index .Tensor (x , indices )
4859 return out
4960
5061 input = [torch .randn (2 , 2 , 2 )]
5162 self .run_test (
5263 TestModule (),
5364 input ,
54- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
65+ expected_ops = {torch .ops .aten .index .Tensor },
5566 )
5667
57- def test_index_zero_index_one_SD (self ):
68+ def test_index_zero_index_one_four_dim (self ):
5869 class TestModule (nn .Module ):
70+ def __init__ (self ):
71+ self .index0 = torch .tensor ([0 , 0 , 1 , 1 ])
72+ self .index1 = torch .tensor ([0 , 0 , 1 , 1 ])
73+ super ().__init__ ()
74+
5975 def forward (self , x ):
60- index0 = torch .tensor ([0 , 0 , 1 , 1 ])
61- index1 = torch .tensor ([0 , 0 , 1 , 1 ])
62- indices = [None , index0 , index1 , None ]
76+ indices = [None , self .index0 , self .index1 , None ]
6377 out = torch .ops .aten .index .Tensor (x , indices )
6478 return out
6579
6680 input = [torch .randn (2 , 4 , 4 , 2 )]
6781 self .run_test (
6882 TestModule (),
6983 input ,
70- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
84+ expected_ops = {torch .ops .aten .index .Tensor },
7185 )
7286
73- def test_index_zero_index_one_SD (self ):
87+ def test_index_zero_index_one_four_dim_SD (self ):
7488 class TestModule (nn .Module ):
89+ def __init__ (self ):
90+ self .index0 = torch .tensor (
91+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
92+ )
93+ self .index1 = torch .tensor (
94+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
95+ )
96+ super ().__init__ ()
97+
7598 def forward (self , x ):
76- index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
77- index1 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
78- indices = [None , index0 , index1 , None ]
99+ indices = [None , self .index0 , self .index1 , None ]
79100 out = torch .ops .aten .index .Tensor (x , indices )
80101 return out
81102
82103 input = [torch .randn (2 , 1280 , 8 , 8 )]
83104 self .run_test (
84105 TestModule (),
85106 input ,
86- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
107+ expected_ops = {torch .ops .aten .index .Tensor },
87108 )
88109
89- def test_index_zero_index_one_SD (self ):
110+ def test_index_one_SD_unsqueeze_four_dim (self ):
90111 class TestModule (nn .Module ):
112+ def __init__ (self ):
113+ self .index0 = torch .tensor (
114+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
115+ )
116+ self .index1 = self .index0 .unsqueeze (0 ).T .long ()
117+ super ().__init__ ()
118+
91119 def forward (self , x ):
92- index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
93- index1 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
94- indices = [None , index0 , index1 , None ]
120+ indices = [None , None , self .index1 , self .index1 ]
95121 out = torch .ops .aten .index .Tensor (x , indices )
96122 return out
97123
98124 input = [torch .randn (2 , 1280 , 8 , 8 )]
99125 self .run_test (
100126 TestModule (),
101127 input ,
102- expected_ops = {torch .ops .aten .index .Tensor , operator . getitem },
128+ expected_ops = {torch .ops .aten .index .Tensor },
103129 )
104130
105- def test_index_zero_index_one_SD_unsqueeze (self ):
131+ def test_index_zero_index_one_index_two_SD_unsqueeze_four_dim_broadcast (self ):
106132 class TestModule (nn .Module ):
133+ def __init__ (self ):
134+ self .index0 = torch .tensor (
135+ [0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ]
136+ )
137+ self .index1 = self .index0 .unsqueeze (0 ).T .long ()
138+ super ().__init__ ()
139+
107140 def forward (self , x ):
108141 index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
109142 index1 = index0 .unsqueeze (0 ).T .long ()
110- indices = [None , None , index1 , index1 ]
143+ indices = [None , None , self . index0 , self . index1 ]
111144 out = torch .ops .aten .index .Tensor (x , indices )
112145 return out
113146
@@ -118,16 +151,19 @@ def forward(self, x):
118151 expected_ops = {torch .ops .aten .index .Tensor },
119152 )
120153
121- def test_index_zero_index_one_index_two_SD_unsqueeze (self ):
154+ def test_index_zero_index_one_index_four_dim_non_continuous (self ):
122155 class TestModule (nn .Module ):
156+ def __init__ (self ):
157+ self .index0 = torch .tensor ([0 , 0 , 1 , 1 ])
158+ self .index1 = torch .tensor ([0 , 0 , 1 , 1 ])
159+ super ().__init__ ()
160+
123161 def forward (self , x ):
124- index0 = torch .tensor ([0 , 0 , 1 , 1 , 2 , 2 , 3 , 3 , 4 , 4 , 5 , 5 , 6 , 6 , 7 , 7 ])
125- index1 = index0 .unsqueeze (0 ).T .long ()
126- indices = [None , None , index0 , index1 ]
162+ indices = [None , self .index0 , None , self .index1 ]
127163 out = torch .ops .aten .index .Tensor (x , indices )
128164 return out
129165
130- input = [torch .randn (2 , 1280 , 8 , 8 )]
166+ input = [torch .randn (2 , 4 , 4 , 2 )]
131167 self .run_test (
132168 TestModule (),
133169 input ,
0 commit comments