@@ -81,30 +81,20 @@ def index(
8181 source_ir : Optional [SourceIR ],
8282 name : str ,
8383 input : TRTTensor ,
84- index : Union [
85- TRTTensor ,
86- Sequence [TRTTensor ],
87- np .ndarray ,
88- Sequence [np .ndarray ],
89- torch .Tensor ,
90- Sequence [torch .Tensor ],
91- ],
84+ index : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
9285) -> TRTTensor :
9386 adv_indx_indices = []
9487 tensor_indices = []
95- # _LOGGER.debug(f"The index shape is {index.shape}")
9688 # check if the input is dynamic
9789 dynamic_shape = has_dynamic_shape (input .shape )
9890 # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
9991 # If any is not this flag will be set to False
100- is_numpy = True
101- _LOGGER .debug (f"Checking for the is_numpy flag" )
102- for i , ind in enumerate (index ):
103- if ind is None :
104- continue
105- if not (isinstance (ind , torch .Tensor ) or isinstance (ind , np .ndarray )):
106- is_numpy = False
107- break
92+ _LOGGER .debug (
93+ f"Determining whether aten.index constant-index optimization can be invoked"
94+ )
95+ is_numpy = all (
96+ isinstance (ind , (torch .Tensor , np .ndarray )) for ind in index if ind is not None
97+ )
10898 # here we need to check if all the index are broadcastable
10999 # if no, then we need to broadcast
110100 last_index = None
@@ -117,7 +107,6 @@ def index(
117107 # other cases are kept as TRTTensor
118108 if is_numpy :
119109 ind = to_numpy (ind )
120- is_numpy = True
121110 else :
122111 ind = get_trt_tensor (ctx , ind , name + f"_parameter_to_fp32_tensor_{ i } " )
123112 if last_index is not None :
@@ -156,9 +145,7 @@ def index(
156145 for i in range (rank ):
157146 dim = input_shape [i ]
158147 dim_tensor = get_trt_tensor (ctx , dim , name + f"_individual_dim_{ i } " )
159- # dim_tensor_list is a list of tensors or numpy
160- if is_numpy :
161- dim_list .append (dim )
148+ # dim_tensor_list is a list of tensors
162149 dim_tensor_list .append (dim_tensor )
163150
164151 # for cases like
@@ -211,12 +198,12 @@ def index(
211198 # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
212199 # // j dimension of input x.
213200 if is_numpy :
214- multiplier = dim_list [adv_indx_indices [adv_indx_count - 1 ]]
201+ multiplier = input_shape [adv_indx_indices [adv_indx_count - 1 ]]
215202 cum_adv_index = tensor_indices [adv_indx_count - 1 ]
216203 for i in range (adv_indx_count - 2 , - 1 , - 1 ):
217204 adv_index = multiplier * tensor_indices [i ]
218205 cum_adv_index = cum_adv_index + adv_index
219- multiplier = multiplier * dim_list [adv_indx_indices [i ]]
206+ multiplier = multiplier * input_shape [adv_indx_indices [i ]]
220207 cum_adv_index = get_trt_tensor (
221208 ctx , cum_adv_index , name + f"_index_sum_intermediate"
222209 )
0 commit comments