Skip to content

Commit 99defd2

Browse files
authored
[Relax][PyTorch] Add support for torch.repeat (#17304)
* add test * add support for torch.repeat * remove debug print
1 parent bf7bbef commit 99defd2

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,14 @@ def _squeeze(self, node: fx.node.Node) -> relax.Var:
640640
dim = None
641641
return self.block_builder.emit(relax.op.squeeze(x, dim))
642642

643+
def _repeat(self, node: fx.node.Node) -> relax.Var:
644+
import torch # type: ignore
645+
646+
args = self.retrieve_args(node)
647+
if isinstance(args[1], (torch.Size, tuple, list)):
648+
return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
649+
return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
650+
643651
def _tile(self, node: fx.node.Node) -> relax.Var:
644652
import torch # type: ignore
645653

@@ -1484,6 +1492,7 @@ def create_convert_map(self):
14841492
"expand": self._expand,
14851493
"flatten": self._flatten,
14861494
"permute": self._permute,
1495+
"repeat": self._repeat,
14871496
"reshape": self._reshape,
14881497
"split": self._split,
14891498
"tile": self._tile,

tests/python/relax/test_frontend_from_fx.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3311,6 +3311,42 @@ def main(
33113311
verify_model(Transpose(), input_info, {}, expected1)
33123312

33133313

3314+
def test_repeat():
3315+
class Tile1(Module):
3316+
def forward(self, x: torch.Tensor):
3317+
return x.repeat(2)
3318+
3319+
class Tile2(Module):
3320+
def forward(self, x: torch.Tensor):
3321+
return x.repeat(4, 2)
3322+
3323+
@tvm.script.ir_module
3324+
class expected1:
3325+
@R.function
3326+
def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((6,), dtype="float32"):
3327+
# block 0
3328+
with R.dataflow():
3329+
lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
3330+
gv: R.Tensor((6,), dtype="float32") = lv
3331+
R.output(gv)
3332+
return gv
3333+
3334+
@tvm.script.ir_module
3335+
class expected2:
3336+
@R.function
3337+
def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"):
3338+
# block 0
3339+
with R.dataflow():
3340+
lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
3341+
gv: R.Tensor((4, 6), dtype="float32") = lv
3342+
R.output(gv)
3343+
return gv
3344+
3345+
verify_model(Tile1(), [([3], "float32")], {}, expected1)
3346+
verify_model(Tile2(), [([1, 3], "float32")], {}, expected2)
3347+
verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)
3348+
3349+
33143350
def test_view():
33153351
input_info = [([1, 2, 3, 4], "float32")]
33163352

0 commit comments

Comments
 (0)