@@ -89,16 +89,14 @@ def index(
8989 # if no, then we need to broadcast
9090
9191 last_index = None
92- broadcast_shape_len = 0
9392 for i , ind in enumerate (index ):
9493 if ind is not None :
9594 _LOGGER .debug (f"Shape of { i } index is { ind .shape } " )
9695 adv_indx_indices .append (i )
9796 # torch.nn.parameter.Parameter=> torch.Tensor
98- ind = get_trt_tensor (network , ind , f"parameter_to_fp32_tensor_ { i } " )
97+ ind = get_trt_tensor (network , ind , name + f"_parameter_to_fp32_tensor_ { i } " )
9998 if last_index is not None :
100- if not (broadcastable (ind , last_index )):
101- assert "The indices should be broadcastable"
99+ assert broadcastable (ind , last_index ), "The indices should be broadcastable!"
102100 last_index = ind
103101 tensor_indices .append (ind )
104102
@@ -128,7 +126,7 @@ def index(
128126
129127 for i in range (rank ):
130128 dim = input_shape [i ]
131- dim_tensor = get_trt_tensor (network , dim , f"individual_dim_{ i } " )
129+ dim_tensor = get_trt_tensor (network , dim , name + f"individual_dim_{ i } " )
132130 # dim_tensor_list is a list of tensors
133131 dim_tensor_list .append (dim_tensor )
134132
@@ -165,8 +163,8 @@ def index(
165163
166164 concat_tensor_layer = network .add_concatenation (
167165 [
168- get_trt_tensor (network , mult_d0 , "d0_shape" ),
169- get_trt_tensor (network , mult_d1 , "d1_shape" ),
166+ get_trt_tensor (network , mult_d0 , name + "d0_shape" ),
167+ get_trt_tensor (network , mult_d1 , name + "d1_shape" ),
170168 ]
171169 )
172170 set_layer_name (concat_tensor_layer , target , name + "_index_Concat" , source_ir )
@@ -181,15 +179,15 @@ def index(
181179 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
182180 # // j dimension of input x.
183181 multiplier = get_trt_tensor (
184- network , dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]], "dim_last"
182+ network , dim_tensor_list [adv_indx_indices [adv_indx_count - 1 ]], name + "dim_last"
185183 )
186184 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
187185 for i in range (adv_indx_count - 2 , - 1 , - 1 ):
188186 adv_index = convert_binary_elementwise (
189187 network ,
190188 target ,
191189 source_ir ,
192- name + "index_intermediate " ,
190+ name + f"index_intermediate_ { i } " ,
193191 trt .ElementWiseOperation .PROD ,
194192 multiplier ,
195193 tensor_indices [i ],
@@ -198,7 +196,7 @@ def index(
198196 network ,
199197 target ,
200198 source_ir ,
201- name + "index_sum_intermediate " ,
199+ name + f"index_sum_intermediate_ { i } " ,
202200 trt .ElementWiseOperation .SUM ,
203201 cum_adv_index ,
204202 adv_index ,
@@ -207,7 +205,7 @@ def index(
207205 network ,
208206 target ,
209207 source_ir ,
210- name + "index_intermediate " ,
208+ name + f"index_intermediate_xj_ { i } " ,
211209 trt .ElementWiseOperation .PROD ,
212210 multiplier ,
213211 dim_tensor_list [adv_indx_indices [i ]],
@@ -235,7 +233,7 @@ def index(
235233 == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
236234 ):
237235 _LOGGER .debug (f"The indices are continuous in this case" )
238- concat_tensor_reshape .append (get_trt_tensor (network , - 1 , "dynamic_concat" ))
236+ concat_tensor_reshape .append (get_trt_tensor (network , - 1 , name + "dynamic_concat" ))
239237 for i in range (0 , rank ):
240238 if i not in adv_indx_indices :
241239 curr_dim = dim_tensor_list [i ]
@@ -294,7 +292,7 @@ def index(
294292 set_layer_name (
295293 concat_final_shape_layer ,
296294 target ,
297- name + "_index_concat_final_shape_layer " ,
295+ name + "_index_continuous_concat_final_shape_layer " ,
298296 source_ir ,
299297 )
300298 concat_final_tensor = concat_final_shape_layer .get_output (0 )
@@ -311,17 +309,19 @@ def index(
311309 reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
312310
313311 else :
314- concat_tensor = []
312+ _LOGGER .debug (f"The indices are not continuous in this case" )
313+ concat_final_tensor = []
314+ concat_final_tensor .append (cum_adv_index_shape_tensor )
315315 for i in range (0 , rank ):
316316 if i not in adv_indx_indices :
317317 curr_dim = dim_tensor_list [i ]
318- concat_tensor .append (curr_dim )
318+ concat_final_tensor .append (curr_dim )
319319
320- concat_layer = network .add_concatenation (concat_tensor )
320+ concat_final_shape_layer = network .add_concatenation (concat_final_tensor )
321321 set_layer_name (
322- concat_layer ,
322+ concat_final_shape_layer ,
323323 target ,
324- name + "_index_concat_final_shape_layer " ,
324+ name + "_index_non_continuous_concat_final_shape_layer " ,
325325 source_ir ,
326326 )
327327 concat_final_tensor = concat_final_shape_layer .get_output (0 )
@@ -331,7 +331,7 @@ def index(
331331 set_layer_name (
332332 reshape_layer ,
333333 target ,
334- name + "_index_shuffle_final_shape_layer " ,
334+ name + "_index_non_continuous_shuffle_final_shape_layer " ,
335335 source_ir ,
336336 )
337337 reshape_output = reshape_layer .get_output (0 )
0 commit comments