diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4c9480b58748..1fa21607a59a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -401,6 +401,16 @@ def call_binary_op(op, lhs, rhs):
 
         return convert
 
+    def _rsub(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        lhs = args[0]
+        rhs = args[1]
+
+        if isinstance(rhs, (int, float)):
+            rhs = relax.const(rhs)
+
+        return self.block_builder.emit(relax.op.subtract(rhs, lhs))
+
     ########## Linear Algebra ##########
 
     def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c82a5e2b1100..ff343c498f4a 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -304,6 +304,8 @@ def create_convert_map(
             "relu_.default": self._unary_op(relax.op.nn.relu),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "rsub.Tensor": self._rsub,
+            "rsub.Scalar": self._rsub,
             "selu.default": self._unary_op(relax.op.nn.selu),
             "sigmoid.default": self._unary_op(relax.op.sigmoid),
             "sign.default": self._unary_op(relax.op.sign),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 297529e8bf29..886a23eb1c0f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -699,6 +699,7 @@ def create_convert_map(
             "pow": self._binary_op(relax.op.power, operator.pow),
             "or_": self._binary_op(relax.op.bitwise_or, operator.or_),
             "rshift": self._binary_op(relax.op.right_shift, operator.rshift),
+            "rsub": self._rsub,
             "sub": self._binary_op(relax.op.subtract, operator.sub),
             "truediv": self._binary_op(relax.op.divide, operator.truediv),
             "xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 26d3d3f7bde2..0cb00d216fdc 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -899,6 +899,7 @@ def test_binary3():
         torch.randn(10, 10, dtype=torch.float32),
         torch.randn(10, 10, dtype=torch.float32),
     )
+    example_args2 = (torch.randn(10, 10, dtype=torch.float32),)
 
     # Max
     class Max1(Module):
@@ -940,6 +941,42 @@ def main(
 
     verify_model(Min1(), example_args1, {}, expected_min1)
 
+    # RSub
+    class RSub1(Module):
+        def forward(self, x, y):
+            return torch.rsub(x, y)
+
+    class RSub2(Module):
+        def forward(self, x):
+            return torch.rsub(x, 5.0)
+
+    @tvm.script.ir_module
+    class expected_rsub1:
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_rsub2:
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(RSub1(), example_args1, {}, expected_rsub1)
+    verify_model(RSub2(), example_args2, {}, expected_rsub2)
+
 
 def test_batchnorm2d():
     class BatchNorm2d(Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index a962de8a3237..4e847be317d4 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1702,6 +1702,45 @@ def main(
         verify_model(Binary2(op), input_info2, {}, expected_binary2)
 
 
+# RSub
+def test_rsub():
+    input_info1 = [([10, 10], "float32"), ([10, 10], "float32")]
+    input_info2 = [([10, 10], "float32")]
+
+    class RSub1(Module):
+        def forward(self, x, y):
+            return torch.rsub(x, y)
+
+    class RSub2(Module):
+        def forward(self, x):
+            return torch.rsub(x, 5.0)
+
+    @tvm.script.ir_module
+    class expected_rsub1:
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected_rsub2:
+        @R.function
+        def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(RSub1(), input_info1, {}, expected_rsub1)
+    verify_model(RSub2(), input_info2, {}, expected_rsub2)
+
+
 def test_size():
     input_info = [([1, 3, 10, 10], "float32")]