
Commit 6b186b6

Adding test case for index

1 parent 4058533 commit 6b186b6

2 files changed: +132 additions, -83 deletions

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 113 additions & 83 deletions
@@ -1,19 +1,19 @@
-from typing import Optional, cast, Union, Sequence
-import tensorrt as trt
+from typing import Optional, Sequence, Union, cast
 
 import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
     get_trt_tensor,
     has_dynamic_shape,
+    set_layer_name,
     to_numpy,
 )
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 
 
 def select(
@@ -73,15 +73,15 @@ def index(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    index: Union[TRTTensor, Sequence[TRTTensor]]
+    index: Union[TRTTensor, Sequence[TRTTensor]],
 ) -> TRTTensor:
     adv_indx_indices = []
     tensor_indices = []
 
     for i in len(index):
         ind = index[i]
-        #FIXME: check if the datatype for the indices needs to be casted to INT32
-        #TRTInterpretor should take care
+        # FIXME: check if the datatype for the indices needs to be casted to INT32
+        # TRTInterpretor should take care
         adv_indx_indices.append(i)
         tensor_indices.append(ind)
 
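The FIXME in the hunk above refers to TensorRT's requirement that gather indices be INT32. A minimal sketch of such a cast, reusing the identity-layer pattern this diff already applies in its zero-index branch (the helper name cast_indices_to_int32 is illustrative, not part of the commit):

    import tensorrt as trt

    def cast_indices_to_int32(network, index_tensor, name):
        # Route the index tensor through an identity layer and force its
        # output type to INT32, which TensorRT gather layers expect.
        identity_layer = network.add_identity(index_tensor)
        identity_layer.set_output_type(0, trt.int32)
        identity_layer.name = name + "_cast_int32"
        return identity_layer.get_output(0)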
@@ -90,7 +90,7 @@ def index(
         identity_layer.set_output_type(0, trt.int32)
         set_layer_name(identity_layer, target, name + "_index_identity", source_ir)
         return identity_layer.get_output(0)
-    elif (len(tensor_indices) == 1):
+    elif len(tensor_indices) == 1:
         indices_tensor = tensor_indices[0]
         gather_layer = network.add_gather(input, indices_tensor, adv_indx_indices[0])
         set_layer_name(gather_layer, target, name + "_index_gather", source_ir)
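The single-index branch above lowers indexing to one gather along the advanced axis; the NumPy equivalence it relies on (a quick illustration, not part of the commit):

    import numpy as np

    x = np.arange(24).reshape(4, 6)
    ind = np.array([2, 0, 3])
    # One advanced index is exactly a gather along that axis.
    assert np.array_equal(x[ind], np.take(x, ind, axis=0))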
@@ -104,20 +104,22 @@ def index(
         input_shape_tensor = input_shape_layer.get_output(0)
         dim_tensor_list = []
         for i in range(rank):
-            #check this
-            dim_tensor_layer = network.add_gather(input_shape_tensor, i ,0)
-            set_layer_name(input_shape_layer, target, name + "_index_gather_rank", source_ir)
+            # check this
+            dim_tensor_layer = network.add_gather(input_shape_tensor, i, 0)
+            set_layer_name(
+                input_shape_layer, target, name + "_index_gather_rank", source_ir
+            )
             dim_tensor = dim_tensor_layer.get_output(0)
             dim_tensor_list.append(dim_tensor)
 
-        #for cases like
-        #t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
-        #where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
-        #for ":"
-        #Examples: x.shape = (10,20,30,40,50)
-        #ind_1, ind_2 broadcasted to (2,3,4)
-        #x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
-        #x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
+        # for cases like
+        # t: [x_1, y_1, y_2, ..., x_m, ..., y_n] -> t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n],
+        # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes
+        # for ":"
+        # Examples: x.shape = (10,20,30,40,50)
+        # ind_1, ind_2 broadcasted to (2,3,4)
+        # x[:, ind_1, ind_2] = 10, 2, 3, 4, 40, 50
+        # x[:,ind_1, :, ind_2] = 2, 3, 4, 10, 30, 50
         transpose_layer = network.add_shuffle(input)
         new_order = []
         for i in range(adv_indx_count):
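The comment block above states the shape rules for consecutive vs. non-consecutive advanced indices; both examples can be checked directly in NumPy (a verification sketch, not part of the commit):

    import numpy as np

    x = np.zeros((10, 20, 30, 40, 50))
    ind_1 = np.zeros((2, 3, 4), dtype=np.int64)  # broadcast index shape (2, 3, 4)
    ind_2 = np.zeros((2, 3, 4), dtype=np.int64)

    # Consecutive advanced indices keep the broadcast shape in place.
    assert x[:, ind_1, ind_2].shape == (10, 2, 3, 4, 40, 50)
    # Non-consecutive advanced indices move the broadcast shape to the front.
    assert x[:, ind_1, :, ind_2].shape == (2, 3, 4, 10, 30, 50)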
@@ -132,36 +134,40 @@ def index(
         set_layer_name(transpose_layer, target, name + "_index_transpose", source_ir)
         transpose_tensor = transpose_layer.get_output(0)
 
-        #Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
+        # Flatten [x_1, x_2,.......x_m, y_1, y_2,.....y_m]
         transpose_tensor_shape = network.add_shape(transpose_tensor)
         d0 = 1
         d0 = get_trt_tensor(network, d0, "d0_initial")
         for i in range(adv_indx_count):
             dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
-            set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir)
+            set_layer_name(
+                dim_tensor_layer, target, name + "_index_gather_concatOne", source_ir
+            )
             d0_gather = gather_layer.get_output(0)
             mult_d0 = convert_binary_elementwise(
-            network,
-            target,
-            source_ir,
-            name + "index_concatOne_shape",
-            trt.ElementWisePROD,
-            mult_d0,
-            d0_gather,
-            )
-
+                network,
+                target,
+                source_ir,
+                name + "index_concatOne_shape",
+                trt.ElementWisePROD,
+                mult_d0,
+                d0_gather,
+            )
+
         d1 = 1
         d1 = get_trt_tensor(network, d0, "d0_initial")
         for i in range(adv_indx_count, rank):
             dim_tensor_layer = network.add_gather(transpose_tensor_shape, i, 0)
-            set_layer_name(dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir)
+            set_layer_name(
+                dim_tensor_layer, target, name + "_index_gather_concatTwo", source_ir
+            )
             d1_gather = gather_layer.get_output(0)
             mult_d1 = convert_binary_elementwise(
-            network,
-            target,
+                network,
+                target,
                 source_ir,
-            name + "index_concatTwo_shape",
-            trt.ElementWisePROD,
+                name + "index_concatTwo_shape",
+                trt.ElementWisePROD,
                 mult_d1,
                 d1_gather,
             )
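The two loops above accumulate d0 and d1, the products of the leading (advanced) and trailing dimensions of the transposed tensor, which become the two axes of the flattening reshape. With static shapes the computation the gather and ElementWisePROD layers perform reduces to (a plain-Python sketch, sizes illustrative):

    import math

    shape = (10, 20, 30, 40, 50)  # transposed so advanced axes come first
    adv_indx_count = 2
    d0 = math.prod(shape[:adv_indx_count])  # 10 * 20 = 200
    d1 = math.prod(shape[adv_indx_count:])  # 30 * 40 * 50 = 60000
    # The flatten target shape is (d0, d1) == (200, 60000).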
@@ -170,126 +176,150 @@ def index(
         concat_tensor = concat_tensor_layer.get_output(0)
 
         reshape_layer = network.add_shuffle(transpose_tensor)
-        #check this
+        # check this
         reshape_layer.set_input(1, concat_tensor)
         flatten_tensor = reshape_layer.get_output(0)
 
-        #tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
-        #// j dimension of input x.
-        multiplier = get_trt_tensor(network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last")
+        # tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j)), ind_i is input indices[i], x_j is the
+        # // j dimension of input x.
+        multiplier = get_trt_tensor(
+            network, dim_tensor_list[adv_indx_indices[adv_indx_count - 1]], "dim_last"
+        )
         cum_adv_index = tensor_indices[adv_indx_count - 1]
-        for i in range(adv_indx_count-2, 0):
+        for i in range(adv_indx_count - 2, 0):
             adv_index = convert_binary_elementwise(
-            network,
-            target,
+                network,
+                target,
                 source_ir,
-            name + "index_intermediate",
-            trt.ElementWisePROD,
+                name + "index_intermediate",
+                trt.ElementWisePROD,
                 multiplier,
                 tensor_indices[i],
             )
             cum_adv_index = convert_binary_elementwise(
-            network,
-            target,
+                network,
+                target,
                 source_ir,
-            name + "index_sum_intermediate",
-            trt.ElementWiseSUM,
+                name + "index_sum_intermediate",
+                trt.ElementWiseSUM,
                 cum_adv_index,
                 adv_index,
             )
             multiplier = convert_binary_elementwise(
-            network,
-            target,
+                network,
+                target,
                 source_ir,
-            name + "index_intermediate",
-            trt.ElementWisePROD,
+                name + "index_intermediate",
+                trt.ElementWisePROD,
                 multiplier,
                 dim_tensor_list[adv_indx_count[i]],
             )
 
         gather_layer_element = network.add_gather(flatten_tensor, cum_adv_index, 0)
-        set_layer_name(gather_layer_element, target, name + "_index_gather_element", source_ir)
+        set_layer_name(
+            gather_layer_element, target, name + "_index_gather_element", source_ir
+        )
         gather_out = gather_layer.get_output(0)
 
         cum_adv_index_shape_tensor = cum_adv_index.add_shape(cum_adv_index_shape_tensor)
-        #check if all advanced indices are consecutive
+        # check if all advanced indices are consecutive
         concat_tensor_reshape = []
-        if(adv_indx_count == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1):
-            #concat_tensor_reshape_initial = -1
-            #concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
+        if (
+            adv_indx_count
+            == adv_indx_indices[adv_indx_count - 1] - adv_indx_indices[0] + 1
+        ):
+            # concat_tensor_reshape_initial = -1
+            # concat_tensor_reshape_initial_tensor = get_trt_tensor(network, concat_tensor_reshape_initial, "concat_tensor_reshape_initial")
             concat_tensor_reshape.append(-1)
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
                     concat_tensor_reshape.append(curr_dim)
-
+
             concat_tensor_layer = network.add_concatenation(concat_tensor_reshape)
-            set_layer_name(concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir)
+            set_layer_name(
+                concat_tensor_layer, target, name + "_index_Concat_reshape", source_ir
+            )
             concat_tensor = concat_tensor_layer.get_output(0)
 
             regular_index_shuffle_layer = network.add_shuffle(gather_out)
-            set_layer_name(regular_index_shuffle_layer, target, name + "_index_regular_index", source_ir)
+            set_layer_name(
+                regular_index_shuffle_layer,
+                target,
+                name + "_index_regular_index",
+                source_ir,
+            )
             unfold_tensor = regular_index_shuffle_layer.get_output(0)
 
             transpose_advanced_shuffle_layer = network.add_shuffle(unfold_tensor)
             new_order = []
-            for i in range(1, adv_indx_count[0]+1):
+            for i in range(1, adv_indx_count[0] + 1):
                 new_order.append(i)
             new_order.append(0)
-            for i in range(adv_indx_indices[0]+1, rank - adv_indx_count):
+            for i in range(adv_indx_indices[0] + 1, rank - adv_indx_count):
                 new_order.append(i)
 
             permute_order = trt.Permutation()
             permute_order(new_order)
             transpose_advanced_shuffle_layer.set_second_transpose(permute_order)
-            set_layer_name(transpose_advanced_shuffle_layer, target, name + "_index_advanced_shuffle_transpose", source_ir)
+            set_layer_name(
+                transpose_advanced_shuffle_layer,
+                target,
+                name + "_index_advanced_shuffle_transpose",
+                source_ir,
+            )
             transpose_tensor = transpose_advanced_shuffle_layer.get_output(0)
 
-            #unfold advanced layer
+            # unfold advanced layer
             concat_final_tensor = []
             for i in range(0, adv_indx_indices[0]):
                 current_dim = dim_tensor_list[i]
                 concat_final_tensor.push_back(curr_dim)
 
             concat_final_tensor.push_back(cum_adv_index_shape_tensor)
             for i in range(adv_indx_indices[0], rank):
-                if(i not in (adv_indx_indices)):
+                if i not in (adv_indx_indices):
                     current_dim = dim_tensor_list[i]
                     concat_final_tensor.append(current_dim)
-
+
             concat_final_shape_layer = network.add_concatenation(concat_final_tensor)
-            set_layer_name(concat_final_shape_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            set_layer_name(
+                concat_final_shape_layer,
+                target,
+                name + "_index_concat_final_shape_layer",
+                source_ir,
+            )
             concat_final_tensor = concat_final_shape_layer.get_output(0)
 
             unfold_advanced_shuffle_layer = network.add_shuffle(transpose_tensor)
-            #check this
+            # check this
             reshape_layer.set_input(1, concat_final_tensor)
             reshape_output = reshape_layer.get_output(0)
-
+
         else:
-            concat_tensor= []
+            concat_tensor = []
             for i in range(0, rank):
                 if i not in adv_indx_indices:
                     curr_dim = dim_tensor_list[i]
                     concat_tensor.append(curr_dim)
-
+
             concat_layer = network.add_concatenation(concat_tensor)
-            set_layer_name(concat_layer, target, name + "_index_concat_final_shape_layer", source_ir)
+            set_layer_name(
+                concat_layer,
+                target,
+                name + "_index_concat_final_shape_layer",
+                source_ir,
+            )
             concat_final_tensor = concat_final_shape_layer.get_output(0)
 
             reshape_layer = network.add_shuffle(gather_out)
             reshape_layer.setInput(1, concat_final_tensor)
-            set_layer_name(reshape_layer, target, name + "_index_shuffle_final_shape_layer", source_ir)
+            set_layer_name(
+                reshape_layer,
+                target,
+                name + "_index_shuffle_final_shape_layer",
+                source_ir,
+            )
             reshape_output = reshape_layer.get_output(0)
 
         return reshape_output
-
-
-
-
-
-
-
-
-
-
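The heart of the multi-index path above is the linearization formula quoted in the diff's comment, tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m x_j). A small NumPy sketch (sizes and index values chosen for illustration) showing why gathering the flattened axes with the cumulative index reproduces advanced indexing:

    import numpy as np

    x = np.arange(10 * 20 * 30).reshape(10, 20, 30)
    ind_1 = np.array([1, 4])
    ind_2 = np.array([2, 7])

    # Flatten the two advanced axes, then linearize: ind_1 * x_2 + ind_2.
    flat = x.reshape(10, 20 * 30)
    cum_adv_index = ind_1 * 30 + ind_2
    # Gathering with the linearized index matches direct advanced indexing.
    assert np.array_equal(flat[:, cum_adv_index], x[:, ind_1, ind_2])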
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestIndexConverter(DispatchTestCase):
+    def test_index(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                input = torch.randn(2, 1280, 8, 8)
+                index0 = torch.randint(0, 16, (1, 16))
+                index1 = torch.randint(0, 16, (1, 16))
+                out = torch.ops.aten.index(None, None, index0, index1)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})
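As committed, forward() discards its result and passes None where aten.index expects the input tensor. For reference, a minimal sketch of the usual pattern for exercising torch.ops.aten.index.Tensor, assuming the (2, 1280, 8, 8) input from the test body (module name and index bounds are illustrative; bounds are kept within the size-8 trailing dims):

    import torch
    import torch.nn as nn

    class IndexModule(nn.Module):
        def forward(self, x):
            index0 = torch.randint(0, 8, (1, 16))
            index1 = torch.randint(0, 8, (1, 16))
            # aten.index.Tensor takes the input plus a list of optional
            # per-dimension index tensors (None stands for ":").
            return torch.ops.aten.index.Tensor(x, [None, None, index0, index1])

    out = IndexModule()(torch.randn(2, 1280, 8, 8))  # shape (2, 1280, 1, 16)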
