diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index d57d24bf2f77..4c06a3e76f8e 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -622,6 +622,7 @@ def create_convert_map(
             "asinh": self._unary_op(relax.op.asinh),
             "atan": self._unary_op(relax.op.atan),
             "atanh": self._unary_op(relax.op.atanh),
+            "bitwise_not": self._unary_op(relax.op.bitwise_not),
             "ceil": self._unary_op(relax.op.ceil),
             "clamp": self._clamp,
             "cos": self._unary_op(relax.op.cos),
@@ -633,19 +634,25 @@ def create_convert_map(
             "gelu": self._gelu,
             "hardsigmoid": self._hardsigmoid,
             "hardswish": self._hardswish,
+            "isfinite": self._unary_op(relax.op.isfinite),
+            "isinf": self._unary_op(relax.op.isinf),
+            "isnan": self._unary_op(relax.op.isnan),
             "leaky_relu": self._leakyrelu,
             "log": self._unary_op(relax.op.log),
+            "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
             "rsqrt": self._unary_op(relax.op.rsqrt),
             "sigmoid": self._unary_op(relax.op.sigmoid),
+            "sign": self._unary_op(relax.op.sign),
             "silu": self._unary_op(relax.op.nn.silu),
             "sin": self._unary_op(relax.op.sin),
             "sinh": self._unary_op(relax.op.sinh),
             "softmax": self._softmax,
             "sqrt": self._unary_op(relax.op.sqrt),
+            "square": self._unary_op(relax.op.square),
             "tan": self._unary_op(relax.op.tan),
             "tanh": self._unary_op(relax.op.tanh),
             "tril_": self._inplace_tril_triu(relax.op.tril),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 446c4149fdde..19bc15b19216 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1938,6 +1938,7 @@ def main(
         (torch.asinh, R.asinh),
         (torch.atan, R.atan),
         (torch.atanh, R.atanh),
+        (torch.bitwise_not, R.bitwise_not),
         (torch.ceil, R.ceil),
         (torch.cos, R.cos),
         (torch.cosh, R.cosh),
@@ -1950,7 +1951,9 @@
         (torch.rsqrt, R.rsqrt),
         (torch.sin, R.sin),
         (torch.sinh, R.sinh),
+        (torch.sign, R.sign),
         (torch.sqrt, R.sqrt),
+        (torch.square, R.square),
         (torch.tan, R.tan),
     ]
 
@@ -2150,6 +2153,25 @@
     verify_model(Hardswish(), input_info, {}, expected_hardswish)
     verify_model(Hardswish2(), input_info, {}, expected_hardswish)
 
+    # logical_not
+    class LogicalNot(Module):
+        def forward(self, input):
+            return torch.logical_not(input)
+
+    @tvm.script.ir_module
+    class expected_logical_not:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.logical_not(inp_0)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(LogicalNot(), input_info, {}, expected_logical_not)
+
     # log_softmax
     class LogSoftmax(Module):
         def __init__(self):
@@ -2179,6 +2201,63 @@ def main(
     verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
     verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)
 
+    # isfinite
+    class IsFinite(Module):
+        def forward(self, input):
+            return torch.isfinite(input)
+
+    @tvm.script.ir_module
+    class expected_isfinite:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isfinite(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(IsFinite(), input_info, {}, expected_isfinite)
+
+    # isinf
+    class IsInf(Module):
+        def forward(self, input):
+            return torch.isinf(input)
+
+    @tvm.script.ir_module
+    class expected_isinf:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isinf(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(IsInf(), input_info, {}, expected_isinf)
+
+    # isnan
+    class IsNan(Module):
+        def forward(self, input):
+            return torch.isnan(input)
+
+    @tvm.script.ir_module
+    class expected_isnan:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isnan(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(IsNan(), input_info, {}, expected_isnan)
+
     # relu
     class ReLU0(Module):
        def __init__(self):
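Beyond the unit tests above, the new mappings can be exercised end to end by tracing a small module and importing it through the FX frontend. The sketch below is not part of the patch: Square is a hypothetical example module, the input shape mirrors the input_info used throughout test_frontend_from_fx.py, and any of the newly supported ops (torch.bitwise_not, torch.isnan, torch.sign, ...) could stand in for torch.square.

import torch
from torch import fx

from tvm.relax.frontend.torch import from_fx

# Hypothetical module using one of the newly mapped unary ops.
class Square(torch.nn.Module):
    def forward(self, x):
        return torch.square(x)

# symbolic_trace yields the torch.fx.GraphModule that the translator consumes.
graph_model = fx.symbolic_trace(Square())

# input_info pairs each graph input's shape with its dtype.
mod = from_fx(graph_model, [([1, 3, 10, 10], "float32")])

# The printed Relax IR should contain a call lowered via relax.op.square.
mod.show()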