Skip to content

Commit a9a6272

Browse files
chohk88peri044
authored andcommitted
feat: support aten.index_select converter (#2710)
1 parent 4314fbc commit a9a6272

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2845,3 +2845,28 @@ def aten_ops_roll(
28452845
args[1],
28462846
args_bounds_check(args, 2, []),
28472847
)
2848+
2849+
2850+
@dynamo_tensorrt_converter(torch.ops.aten.index_select.default)
2851+
@enforce_tensor_types(
2852+
{
2853+
0: (TRTTensor,),
2854+
2: (TRTTensor,),
2855+
}
2856+
)
2857+
def aten_ops_index_select(
2858+
ctx: ConversionContext,
2859+
target: Target,
2860+
args: Tuple[Argument, ...],
2861+
kwargs: Dict[str, Argument],
2862+
name: str,
2863+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2864+
return impl.select.index_select(
2865+
ctx,
2866+
target,
2867+
SourceIR.ATEN,
2868+
name,
2869+
args[0],
2870+
args[1],
2871+
args[2],
2872+
)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestIndexSelectConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("1d_input", (10,), 0, (1,)),
13+
("2d_input_dim_0", (10, 3), 0, (0, 2)),
14+
("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
15+
("2d_input_dim_-2", (5, 10), -2, (1, 2, 3)),
16+
("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
17+
("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
18+
("3d_input_dim_-1", (10, 5, 10), -1, (3, 3, 4)),
19+
("3d_input_dim_-3", (10, 5, 10), -3, (5, 3, 4)),
20+
]
21+
)
22+
def test_index_select(self, _, source_shape, dim, indices_val):
23+
class TestIndexSelect(torch.nn.Module):
24+
def forward(self, source_tensor, indices_tensor):
25+
return torch.ops.aten.index_select.default(
26+
source_tensor, dim, indices_tensor
27+
)
28+
29+
input = [
30+
torch.randn(*source_shape, dtype=torch.float32),
31+
torch.tensor([*indices_val], dtype=torch.int32),
32+
]
33+
34+
self.run_test(
35+
TestIndexSelect(),
36+
input,
37+
)
38+
39+
40+
if __name__ == "__main__":
41+
run_tests()

0 commit comments

Comments
 (0)