diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 9a912bbb6b63..e873f6485887 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -2116,6 +2116,8 @@ def exp(x): The result. """ x = tir.convert(x) + if "int" in x.dtype: + x = tir.Cast("float64", x) return call_intrin(x.dtype, "tir.exp", x)