@@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
862862 Span span = Span());
863863
864864// Intrinsic operators
865- #define TVM_DECLARE_INTRIN_UNARY (OpName ) \
866- inline PrimExpr OpName (PrimExpr x, Span span = Span()) { \
867- static const Op& op = Op::Get (" tir." #OpName); \
868- if (x.dtype ().is_bfloat16 ()) { \
869- DataType srcType = x.dtype (); \
870- DataType dstType (kDLFloat , 32 , srcType .lanes ()); \
871- PrimExpr castX = tir::Cast (dstType , {x}, span); \
872- PrimExpr result = tir::Call (dstType , op, {castX }, span); \
873- return tir::Cast (srcType , {result }, span); \
874- } else { \
875- return tir::Call (x.dtype (), op, {x}, span); \
876- } \
865+ #define TVM_DECLARE_INTRIN_UNARY (OpName ) \
866+ inline PrimExpr OpName (PrimExpr x, Span span = Span()) { \
867+ static const Op& op = Op::Get (" tir." #OpName); \
868+ if (x.dtype ().is_bfloat16 ()) { \
869+ DataType bf16_dtype = x.dtype (); \
870+ DataType fp32_dtype (kDLFloat , 32 , bf16_dtype .lanes ()); \
871+ PrimExpr x_fp32 = tir::Cast (fp32_dtype , {x}, span); \
872+ PrimExpr result_fp32 = tir::Call (fp32_dtype , op, {x_fp32 }, span); \
873+ return tir::Cast (bf16_dtype , {result_fp32 }, span); \
874+ } else { \
875+ return tir::Call (x.dtype (), op, {x}, span); \
876+ } \
877877 }
878878
879879TVM_DECLARE_INTRIN_UNARY (exp);
0 commit comments