Skip to content

Commit 1d78f5b

Browse files
committed
refine the code style (apache#10112)
1 parent adcf199 commit 1d78f5b

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

include/tvm/tir/op.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -862,18 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
862862
Span span = Span());
863863

864864
// Intrinsic operators
865-
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
866-
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
867-
static const Op& op = Op::Get("tir." #OpName); \
868-
if (x.dtype().is_bfloat16()) { \
869-
DataType srcType = x.dtype(); \
870-
DataType dstType(kDLFloat, 32, srcType.lanes()); \
871-
PrimExpr castX = tir::Cast(dstType, {x}, span); \
872-
PrimExpr result = tir::Call(dstType, op, {castX}, span); \
873-
return tir::Cast(srcType, {result}, span); \
874-
} else { \
875-
return tir::Call(x.dtype(), op, {x}, span); \
876-
} \
865+
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
866+
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
867+
static const Op& op = Op::Get("tir." #OpName); \
868+
if (x.dtype().is_bfloat16()) { \
869+
DataType bf16_dtype = x.dtype(); \
870+
DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \
871+
PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \
872+
PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \
873+
return tir::Cast(bf16_dtype, {result_fp32}, span); \
874+
} else { \
875+
return tir::Call(x.dtype(), op, {x}, span); \
876+
} \
877877
}
878878

879879
TVM_DECLARE_INTRIN_UNARY(exp);

0 commit comments

Comments
 (0)