1- from typing import Optional , cast , Union , Sequence
2- import tensorrt as trt
1+ from typing import Optional , Sequence , Union , cast
32
43import numpy as np
4+ import tensorrt as trt
55from torch .fx .node import Target
66from torch_tensorrt .dynamo ._SourceIR import SourceIR
7- from torch_tensorrt .dynamo .conversion .impl .shape import get_shape_with_dynamic_shape
87from torch_tensorrt .dynamo .conversion .impl .elementwise import convert_binary_elementwise
8+ from torch_tensorrt .dynamo .conversion .impl .shape import get_shape_with_dynamic_shape
99from torch_tensorrt .fx .converters .converter_utils import (
1010 get_positive_dim ,
1111 get_trt_tensor ,
1212 has_dynamic_shape ,
13+ set_layer_name ,
1314 to_numpy ,
1415)
1516from torch_tensorrt .fx .types import Shape , TRTNetwork , TRTTensor
16- from torch_tensorrt .fx .converters .converter_utils import set_layer_name
1717
1818
1919def select (
@@ -73,15 +73,15 @@ def index(
7373 source_ir : Optional [SourceIR ],
7474 name : str ,
7575 input : TRTTensor ,
76- index : Union [TRTTensor , Sequence [TRTTensor ]]
76+ index : Union [TRTTensor , Sequence [TRTTensor ]],
7777) -> TRTTensor :
7878 adv_indx_indices = []
7979 tensor_indices = []
8080
8181 for i in len (index ):
8282 ind = index [i ]
83- #FIXME: check if the datatype for the indices needs to be casted to INT32
84- #TRTInterpretor should take care
83+ # FIXME: check if the datatype for the indices needs to be casted to INT32
84+ # TRTInterpretor should take care
8585 adv_indx_indices .append (i )
8686 tensor_indices .append (ind )
8787
@@ -90,7 +90,7 @@ def index(
9090 identity_layer .set_output_type (0 , trt .int32 )
9191 set_layer_name (identity_layer , target , name + "_index_identity" , source_ir )
9292 return identity_layer .get_output (0 )
93- elif ( len (tensor_indices ) == 1 ) :
93+ elif len (tensor_indices ) == 1 :
9494 indices_tensor = tensor_indices [0 ]
9595 gather_layer = network .add_gather (input , indices_tensor , adv_indx_indices [0 ])
9696 set_layer_name (gather_layer , target , name + "_index_gather" , source_ir )
@@ -104,20 +104,22 @@ def index(
104104 input_shape_tensor = input_shape_layer .get_output (0 )
105105 dim_tensor_list = []
106106 for i in range (rank ):
107- #check this
108- dim_tensor_layer = network .add_gather (input_shape_tensor , i ,0 )
109- set_layer_name (input_shape_layer , target , name + "_index_gather_rank" , source_ir )
107+ # check this
108+ dim_tensor_layer = network .add_gather (input_shape_tensor , i , 0 )
109+ set_layer_name (
110+ input_shape_layer , target , name + "_index_gather_rank" , source_ir
111+ )
110112 dim_tensor = dim_tensor_layer .get_output (0 )
111113 dim_tensor_list .append (dim_tensor )
112114
113- #for cases like
114- #t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
115- #where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
116- #for ":"
117- #Examples: x.shape = (10,20,30,40,50)
118- #ind_1, ind_2 broadcasted to (2,3,4)
119- #x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
120- #x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
115+ # for cases like
116+ # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
117+ # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
118+ # for ":"
119+ # Examples: x.shape = (10,20,30,40,50)
120+ # ind_1, ind_2 broadcasted to (2,3,4)
121+ # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
122+ # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
121123 transpose_layer = network .add_shuffle (input )
122124 new_order = []
123125 for i in range (adv_indx_count ):
@@ -132,36 +134,40 @@ def index(
132134 set_layer_name (transpose_layer , target , name + "_index_transpose" , source_ir )
133135 transpose_tensor = transpose_layer .get_output (0 )
134136
135- #Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
137+ # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
136138 transpose_tensor_shape = network .add_shape (transpose_tensor )
137139 d0 = 1
138140 d0 = get_trt_tensor (network , d0 , "d0_initial" )
139141 for i in range (adv_indx_count ):
140142 dim_tensor_layer = network .add_gather (transpose_tensor_shape , i , 0 )
141- set_layer_name (dim_tensor_layer , target , name + "_index_gather_concatOne" , source_ir )
143+ set_layer_name (
144+ dim_tensor_layer , target , name + "_index_gather_concatOne" , source_ir
145+ )
142146 d0_gather = gather_layer .get_output (0 )
143147 mult_d0 = convert_binary_elementwise (
144- network ,
145- target ,
146- source_ir ,
147- name + "index_concatOne_shape" ,
148- trt .ElementWisePROD ,
149- mult_d0 ,
150- d0_gather ,
151- )
152-
148+ network ,
149+ target ,
150+ source_ir ,
151+ name + "index_concatOne_shape" ,
152+ trt .ElementWisePROD ,
153+ mult_d0 ,
154+ d0_gather ,
155+ )
156+
153157 d1 = 1
154158 d1 = get_trt_tensor (network , d0 , "d0_initial" )
155159 for i in range (adv_indx_count , rank ):
156160 dim_tensor_layer = network .add_gather (transpose_tensor_shape , i , 0 )
157- set_layer_name (dim_tensor_layer , target , name + "_index_gather_concatTwo" , source_ir )
161+ set_layer_name (
162+ dim_tensor_layer , target , name + "_index_gather_concatTwo" , source_ir
163+ )
158164 d1_gather = gather_layer .get_output (0 )
159165 mult_d1 = convert_binary_elementwise (
160- network ,
161- target ,
166+ network ,
167+ target ,
162168 source_ir ,
163- name + "index_concatTwo_shape" ,
164- trt .ElementWisePROD ,
169+ name + "index_concatTwo_shape" ,
170+ trt .ElementWisePROD ,
165171 mult_d1 ,
166172 d1_gather ,
167173 )
@@ -170,126 +176,150 @@ def index(
170176 concat_tensor = concat_tensor_layer .get_output (0 )
171177
172178 reshape_layer = network .add_shuffle (transpose_tensor )
173- #check this
179+ # check this
174180 reshape_layer .set_input (1 , concat_tensor )
175181 flatten_tensor = reshape_layer .get_output (0 )
176182
177- #tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
178- #// j dimension of input x.
179- multiplier = get_trt_tensor (network , dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]], "dim_last" )
183+ # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
184+ # // j dimension of input x.
185+ multiplier = get_trt_tensor (
186+ network , dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]], "dim_last"
187+ )
180188 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
181- for i in range (adv_indx_count - 2 , 0 ):
189+ for i in range (adv_indx_count - 2 , 0 ):
182190 adv_index = convert_binary_elementwise (
183- network ,
184- target ,
191+ network ,
192+ target ,
185193 source_ir ,
186- name + "index_intermediate" ,
187- trt .ElementWisePROD ,
194+ name + "index_intermediate" ,
195+ trt .ElementWisePROD ,
188196 multiplier ,
189197 tensor_indices [i ],
190198 )
191199 cum_adv_index = convert_binary_elementwise (
192- network ,
193- target ,
200+ network ,
201+ target ,
194202 source_ir ,
195- name + "index_sum_intermediate" ,
196- trt .ElementWiseSUM ,
203+ name + "index_sum_intermediate" ,
204+ trt .ElementWiseSUM ,
197205 cum_adv_index ,
198206 adv_index ,
199207 )
200208 multiplier = convert_binary_elementwise (
201- network ,
202- target ,
209+ network ,
210+ target ,
203211 source_ir ,
204- name + "index_intermediate" ,
205- trt .ElementWisePROD ,
212+ name + "index_intermediate" ,
213+ trt .ElementWisePROD ,
206214 multiplier ,
207215 dim_tensor_list [adv_indx_count [i ]],
208216 )
209217
210218 gather_layer_element = network .add_gather (flatten_tensor , cum_adv_index , 0 )
211- set_layer_name (gather_layer_element , target , name + "_index_gather_element" , source_ir )
219+ set_layer_name (
220+ gather_layer_element , target , name + "_index_gather_element" , source_ir
221+ )
212222 gather_out = gather_layer .get_output (0 )
213223
214224 cum_adv_index_shape_tensor = cum_adv_index .add_shape (cum_adv_index_shape_tensor )
215- #check if all advanced indices are consecutive
225+ # check if all advanced indices are consecutive
216226 concat_tensor_reshape = []
217- if (adv_indx_count == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1 ):
218- #concat_tensor_reshape_initial = -1
219- #concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
227+ if (
228+ adv_indx_count
229+ == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
230+ ):
231+ # concat_tensor_reshape_initial = -1
232+ # concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
220233 concat_tensor_reshape .append (- 1 )
221234 for i in range (0 , rank ):
222235 if i not in adv_indx_indices :
223236 curr_dim = dim_tensor_list [i ]
224237 concat_tensor_reshape .append (curr_dim )
225-
238+
226239 concat_tensor_layer = network .add_concatenation (concat_tensor_reshape )
227- set_layer_name (concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir )
240+ set_layer_name (
241+ concat_tensor_layer , target , name + "_index_Concat_reshape" , source_ir
242+ )
228243 concat_tensor = concat_tensor_layer .get_output (0 )
229244
230245 regular_index_shuffle_layer = network .add_shuffle (gather_out )
231- set_layer_name (regular_index_shuffle_layer , target , name + "_index_regular_index" , source_ir )
246+ set_layer_name (
247+ regular_index_shuffle_layer ,
248+ target ,
249+ name + "_index_regular_index" ,
250+ source_ir ,
251+ )
232252 unfold_tensor = regular_index_shuffle_layer .get_output (0 )
233253
234254 transpose_advanced_shuffle_layer = network .add_shuffle (unfold_tensor )
235255 new_order = []
236- for i in range (1 , adv_indx_count [0 ]+ 1 ):
256+ for i in range (1 , adv_indx_count [0 ] + 1 ):
237257 new_order .append (i )
238258 new_order .append (0 )
239- for i in range (adv_indx_indices [0 ]+ 1 , rank - adv_indx_count ):
259+ for i in range (adv_indx_indices [0 ] + 1 , rank - adv_indx_count ):
240260 new_order .append (i )
241261
242262 permute_order = trt .Permutation ()
243263 permute_order (new_order )
244264 transpose_advanced_shuffle_layer .set_second_transpose (permute_order )
245- set_layer_name (transpose_advanced_shuffle_layer , target , name + "_index_advanced_shuffle_transpose" , source_ir )
265+ set_layer_name (
266+ transpose_advanced_shuffle_layer ,
267+ target ,
268+ name + "_index_advanced_shuffle_transpose" ,
269+ source_ir ,
270+ )
246271 transpose_tensor = transpose_advanced_shuffle_layer .get_output (0 )
247272
248- #unfold advanced layer
273+ # unfold advanced layer
249274 concat_final_tensor = []
250275 for i in range (0 , adv_indx_indices [0 ]):
251276 current_dim = dim_tensor_list [i ]
252277 concat_final_tensor .push_back (curr_dim )
253278
254279 concat_final_tensor .push_back (cum_adv_index_shape_tensor )
255280 for i in range (adv_indx_indices [0 ], rank ):
256- if ( i not in (adv_indx_indices ) ):
281+ if i not in (adv_indx_indices ):
257282 current_dim = dim_tensor_list [i ]
258283 concat_final_tensor .append (current_dim )
259-
284+
260285 concat_final_shape_layer = network .add_concatenation (concat_final_tensor )
261- set_layer_name (concat_final_shape_layer , target , name + "_index_concat_final_shape_layer" , source_ir )
286+ set_layer_name (
287+ concat_final_shape_layer ,
288+ target ,
289+ name + "_index_concat_final_shape_layer" ,
290+ source_ir ,
291+ )
262292 concat_final_tensor = concat_final_shape_layer .get_output (0 )
263293
264294 unfold_advanced_shuffle_layer = network .add_shuffle (transpose_tensor )
265- #check this
295+ # check this
266296 reshape_layer .set_input (1 , concat_final_tensor )
267297 reshape_output = reshape_layer .get_output (0 )
268-
298+
269299 else :
270- concat_tensor = []
300+ concat_tensor = []
271301 for i in range (0 , rank ):
272302 if i not in adv_indx_indices :
273303 curr_dim = dim_tensor_list [i ]
274304 concat_tensor .append (curr_dim )
275-
305+
276306 concat_layer = network .add_concatenation (concat_tensor )
277- set_layer_name (concat_layer , target , name + "_index_concat_final_shape_layer" , source_ir )
307+ set_layer_name (
308+ concat_layer ,
309+ target ,
310+ name + "_index_concat_final_shape_layer" ,
311+ source_ir ,
312+ )
278313 concat_final_tensor = concat_final_shape_layer .get_output (0 )
279314
280315 reshape_layer = network .add_shuffle (gather_out )
281316 reshape_layer .setInput (1 , concat_final_tensor )
282- set_layer_name (reshape_layer , target , name + "_index_shuffle_final_shape_layer" , source_ir )
317+ set_layer_name (
318+ reshape_layer ,
319+ target ,
320+ name + "_index_shuffle_final_shape_layer" ,
321+ source_ir ,
322+ )
283323 reshape_output = reshape_layer .get_output (0 )
284324
285325 return reshape_output
286-
287-
288-
289-
290-
291-
292-
293-
294-
295-
0 commit comments