Skip to content

Commit ec548eb

Browse files
authored
[Relax][PyTorch] Add support for gather, flip and take ops (#17707)
* Update test_frontend_from_fx.py * Update fx_translator.py * Update base_fx_graph_translator.py * Update base_fx_graph_translator.py * Update base_fx_graph_translator.py * Update test_frontend_from_fx.py
1 parent 8d555a0 commit ec548eb

File tree

3 files changed

+158
-0
lines changed

3 files changed

+158
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,21 @@ def _expand(self, node: fx.Node) -> relax.Var:
847847
broadcast_shape.append(i)
848848
return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
849849

850+
def _flip(self, node: fx.Node) -> relax.Var:
851+
x = self.env[node.args[0]]
852+
dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
853+
if isinstance(dims, (list, tuple)) and len(dims) > 0:
854+
dims = dims[0]
855+
elif not isinstance(dims, int):
856+
raise TypeError(f"flip expects an integer axis, but got {type(dims)}: {dims}")
857+
return self.block_builder.emit(relax.op.flip(x, dims))
858+
859+
def _gather(self, node: fx.Node) -> relax.Var:
860+
x = self.env[node.args[0]]
861+
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
862+
index = self.env[node.args[2]]
863+
return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
864+
850865
def _permute(self, node: fx.Node) -> relax.Var:
851866
import torch # type: ignore
852867

@@ -921,6 +936,12 @@ def _stack(self, node: fx.Node) -> relax.Var:
921936
s_shape.append(s)
922937
return self.block_builder.emit(relax.op.reshape(cat, s_shape))
923938

939+
def _take(self, node: fx.Node) -> relax.Var:
940+
x = self.env[node.args[0]]
941+
indices = self.env[node.args[1]]
942+
indices = self.block_builder.emit(relax.op.astype(indices, "int32"))
943+
return self.block_builder.emit(relax.op.take(x, indices))
944+
924945
def _tile(self, node: fx.Node) -> relax.Var:
925946
import torch # type: ignore
926947

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,8 @@ def create_convert_map(
733733
"cumsum": self._cumsum,
734734
"expand": self._expand,
735735
"flatten": self._flatten,
736+
"flip": self._flip,
737+
"gather": self._gather,
736738
"permute": self._permute,
737739
"repeat": self._repeat,
738740
"reshape": self._reshape,
@@ -741,6 +743,7 @@ def create_convert_map(
741743
"split": self._split,
742744
"squeeze": self._squeeze,
743745
"stack": self._stack,
746+
"take": self._take,
744747
"tile": self._tile,
745748
"transpose": self._transpose,
746749
"unsqueeze": lambda node: self.block_builder.emit(

tests/python/relax/test_frontend_from_fx.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3903,5 +3903,139 @@ def main(inp_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((), dtype="bool")
39033903
verify_model(IsFloatingPoint(), [([2, 3], "float32")], {}, Expected)
39043904

39053905

3906+
def test_gather():
    """torch.gather should lower to R.gather_elements with the matching axis.

    Four modules cover positive (0, 1) and negative (-1, -2) dims; each has
    a hand-written TVMScript Expected module that verify_model compares the
    translated IR against structurally.
    """

    class Gather0(Module):
        def forward(self, data, indices):
            return torch.gather(data, 0, indices)

    class Gather1(Module):
        def forward(self, data, indices):
            return torch.gather(data, 1, indices)

    class Gather2(Module):
        def forward(self, data, indices):
            return torch.gather(data, -1, indices)

    class Gather3(Module):
        def forward(self, data, indices):
            return torch.gather(data, -2, indices)

    # Expected IR: dim is forwarded unchanged (including negative values)
    # as the `axis` of R.gather_elements.
    @tvm.script.ir_module
    class Expected0:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int32"),
        ) -> R.Tensor((2, 3), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=0)
                gv: R.Tensor((2, 3), dtype="float32") = lv
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int32"),
        ) -> R.Tensor((2, 3), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=1)
                gv: R.Tensor((2, 3), dtype="float32") = lv
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int32"),
        ) -> R.Tensor((2, 3), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-1)
                gv: R.Tensor((2, 3), dtype="float32") = lv
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected3:
        @R.function
        def main(
            inp_0: R.Tensor((2, 3), dtype="float32"),
            inp_1: R.Tensor((2, 3), dtype="int32"),
        ) -> R.Tensor((2, 3), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-2)
                gv: R.Tensor((2, 3), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Gather0(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected0)
    verify_model(Gather1(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected1)
    verify_model(Gather2(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected2)
    verify_model(Gather3(), [([2, 3], "float32"), ([2, 3], "int32")], {}, Expected3)
3979+
3980+
3981+
def test_flip():
    """torch.flip with a single-axis dims list should lower to R.flip.

    Covers axis 0 and axis 1; the single element of the `dims` list becomes
    the `axis` argument of R.flip in the Expected IR.
    """

    class Flip0(Module):
        def forward(self, data):
            return torch.flip(data, [0])

    class Flip1(Module):
        def forward(self, data):
            return torch.flip(data, [1])

    @tvm.script.ir_module
    class Expected0:
        @R.function
        def main(
            inp_0: R.Tensor((2, 2), dtype="float32"),
        ) -> R.Tensor((2, 2), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0)
                gv: R.Tensor((2, 2), dtype="float32") = lv
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((2, 2), dtype="float32"),
        ) -> R.Tensor((2, 2), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1)
                gv: R.Tensor((2, 2), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Flip0(), [([2, 2], "float32")], {}, Expected0)
    verify_model(Flip1(), [([2, 2], "float32")], {}, Expected1)
4016+
4017+
4018+
def test_take():
    """torch.take should lower to an R.astype(int32) cast followed by R.take.

    The Expected IR includes the astype even though the input indices are
    already int32 — the translator emits the cast unconditionally.
    """

    class Take(Module):
        def forward(self, data, indices):
            return torch.take(data, indices)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            inp_0: R.Tensor((5,), dtype="float32"),
            inp_1: R.Tensor((3,), dtype="int32"),
        ) -> R.Tensor((3,), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, "int32")
                lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv)
                gv: R.Tensor((3,), dtype="float32") = lv1
                R.output(gv)
            return gv

    verify_model(Take(), [([5], "float32"), ([3], "int32")], {}, Expected)
4038+
4039+
39064040
if __name__ == "__main__":
39074041
tvm.testing.main()

0 commit comments

Comments
 (0)