Skip to content

Commit 7e2467a

Browse files
authored
[TIR, Relay] improve bfloat16 support (#10112)
* update AMP table to enable ResNet50 conversion * add runtime datatype dispatch for BFloat16 * skip asserts for uint16 for bf16 compatibility * add bf16 cast for the unary intrinsic operators * enable "bf16<-->fp32<-->any dtype" casting * support inconsistent input for bf16 BIOP legalize * add treatments for bfloat16 in if statements * add bfloat16 dtype casts in binary OP * delete unnecessary treatments for bfloat16 * add test for bfloat16 building * code style * restore the modifications in .gitignore * restore the changes to AMP lists * fix typos * fix lint errors * fix typo
1 parent cb7f773 commit 7e2467a

File tree

15 files changed

+177
-95
lines changed

15 files changed

+177
-95
lines changed

include/tvm/tir/op.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,10 +862,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
862862
Span span = Span());
863863

864864
// Intrinsic operators
865-
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
866-
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
867-
static const Op& op = Op::Get("tir." #OpName); \
868-
return tir::Call(x.dtype(), op, {x}, span); \
865+
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
866+
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
867+
static const Op& op = Op::Get("tir." #OpName); \
868+
if (x.dtype().is_bfloat16()) { \
869+
DataType srcType = x.dtype(); \
870+
DataType dstType(kDLFloat, 32, srcType.lanes()); \
871+
PrimExpr castX = tir::Cast(dstType, {x}, span); \
872+
PrimExpr result = tir::Call(dstType, op, {castX}, span); \
873+
return tir::Cast(srcType, {result}, span); \
874+
} else { \
875+
return tir::Call(x.dtype(), op, {x}, span); \
876+
} \
869877
}
870878

871879
TVM_DECLARE_INTRIN_UNARY(exp);

src/arith/rewrite_simplify.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
461461

462462
// x / 2.0 = x * 0.5
463463
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
464-
ICHECK(op->dtype.is_float() ||
464+
ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
465465
datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
466466
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
467467
}

src/auto_scheduler/feature.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,14 +246,14 @@ int64_t GetLoopExtent(const ForNode* node) {
246246
// Count math ops in an expr
247247
class MathOpCounter : public StmtExprVisitor {
248248
public:
249-
#define VisitBinary(Type, float_ct, int_ct) \
250-
void VisitExpr_(const Type* op) final { \
251-
if (op->a.dtype().is_float()) { \
252-
float_ct++; \
253-
} else { \
254-
int_ct++; \
255-
} \
256-
StmtExprVisitor::VisitExpr_(op); \
249+
#define VisitBinary(Type, float_ct, int_ct) \
250+
void VisitExpr_(const Type* op) final { \
251+
if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \
252+
float_ct++; \
253+
} else { \
254+
int_ct++; \
255+
} \
256+
StmtExprVisitor::VisitExpr_(op); \
257257
}
258258

259259
VisitBinary(AddNode, float_addsub, int_addsub);
@@ -299,13 +299,13 @@ class MathOpCounter : public StmtExprVisitor {
299299
effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;
300300

301301
if (is_pure) {
302-
if (op->dtype.is_float()) {
302+
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
303303
float_math_func++;
304304
} else {
305305
int_math_func++;
306306
}
307307
} else {
308-
if (op->dtype.is_float()) {
308+
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
309309
float_other_func++;
310310
} else {
311311
int_other_func++;

src/autotvm/touch_extractor.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,27 +87,37 @@ class TouchExtractor : public FeatureVisitor {
8787

8888
// arithmetic stats
8989
void VisitExpr_(const AddNode* op) final {
90-
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
90+
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
91+
itervar_map[itervar_stack_.back()].add_ct++;
92+
}
9193
FeatureVisitor::VisitExpr_(op);
9294
}
9395

9496
void VisitExpr_(const SubNode* op) final {
95-
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
97+
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
98+
itervar_map[itervar_stack_.back()].add_ct++;
99+
}
96100
FeatureVisitor::VisitExpr_(op);
97101
}
98102

99103
void VisitExpr_(const MulNode* op) final {
100-
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++;
104+
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
105+
itervar_map[itervar_stack_.back()].mul_ct++;
106+
}
101107
FeatureVisitor::VisitExpr_(op);
102108
}
103109

104110
void VisitExpr_(const DivNode* op) final {
105-
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
111+
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
112+
itervar_map[itervar_stack_.back()].div_ct++;
113+
}
106114
FeatureVisitor::VisitExpr_(op);
107115
}
108116

109117
void VisitExpr_(const ModNode* op) final {
110-
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
118+
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
119+
itervar_map[itervar_stack_.back()].div_ct++;
120+
}
111121
FeatureVisitor::VisitExpr_(op);
112122
}
113123

src/contrib/hybrid/codegen_hybrid.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
6969
} else if (t.is_int()) {
7070
os << "int";
7171
ICHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
72+
} else if (t.is_bfloat16()) {
73+
os << "bfloat";
74+
ICHECK(t.bits() == 16);
7275
} else {
7376
ICHECK(t.is_uint()) << "Unsupported type " << t;
7477
os << "uint";

src/relay/backend/contrib/codegen_c/codegen_c.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ class CodegenCBase {
363363
dtype = "float";
364364
} else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) {
365365
dtype = "half";
366+
} else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) {
367+
dtype = "bfloat";
366368
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
367369
dtype = "int";
368370
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {

src/relay/backend/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
359359
os << "int";
360360
} else if (dtype.is_uint()) {
361361
os << "uint";
362+
} else if (dtype.is_bfloat16()) {
363+
os << "bfloat";
362364
} else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
363365
os << "custom["
364366
<< (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()

src/relay/op/nn/nn.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
11651165
<< ", weights shape = " << weights->shape);
11661166
return false;
11671167
}
1168-
if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
1168+
if (!(predictions->dtype == weights->dtype &&
1169+
(predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {
11691170
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
11701171
<< "NLLLossRel: predictions and weights should"
11711172
<< " be of the same floating type.");

src/relay/transforms/pattern_utils.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ namespace relay {
6363
} else if (type == DataType::Float(16)) { \
6464
typedef uint16_t DType; \
6565
{ __VA_ARGS__ } \
66+
} else if (type == DataType::BFloat(16)) { \
67+
typedef uint16_t DType; \
68+
{ __VA_ARGS__ } \
6669
} else if (type == DataType::Int(64)) { \
6770
typedef int64_t DType; \
6871
{ __VA_ARGS__ } \
@@ -259,6 +262,11 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
259262
// storage is uint16_t
260263
*static_cast<DType*>(arr->data) =
261264
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
265+
} else if (dtype == DataType::BFloat(16)) {
266+
// convert to bfloat16
267+
// storage is uint16_t
268+
*static_cast<DType*>(arr->data) =
269+
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(value));
262270
} else {
263271
*static_cast<DType*>(arr->data) = value;
264272
}
@@ -286,6 +294,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
286294
*(static_cast<DType*>(arr->data) + i) =
287295
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
288296
static_cast<float>(value[i]));
297+
} else if (dtype == DataType::BFloat(16)) {
298+
// convert to bfloat16
299+
// storage is uint16_t
300+
*(static_cast<DType*>(arr->data) + i) =
301+
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
302+
static_cast<float>(value[i]));
289303
} else {
290304
*(static_cast<DType*>(arr->data) + i) = value[i];
291305
}
@@ -314,6 +328,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
314328
*(static_cast<DType*>(arr->data) + i) =
315329
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
316330
static_cast<float>(value[i]));
331+
} else if (dtype == DataType::BFloat(16)) {
332+
// convert to bfloat16
333+
// storage is uint16_t
334+
*(static_cast<DType*>(arr->data) + i) =
335+
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
336+
static_cast<float>(value[i]));
317337
} else {
318338
*(static_cast<DType*>(arr->data) + i) = value[i];
319339
}
@@ -417,6 +437,12 @@ static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& ar
417437
} else if (array->dtype.bits == 64) {
418438
return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
419439
}
440+
} else if (array->dtype.code == kDLBfloat) {
441+
if (array->dtype.bits == 16) {
442+
return dmlc::optional<long double>(
443+
__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
444+
reinterpret_cast<uint16_t*>(array->data)[i]));
445+
}
420446
}
421447
return dmlc::optional<long double>();
422448
}

src/runtime/crt/common/packed_func.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ DLDataType String2DLDataType(const char* s) {
4949
} else if (!strncmp(s, "float", 5)) {
5050
t.code = kDLFloat;
5151
scan = s + 5;
52+
} else if (!strncmp(s, "bfloat", 6)) {
53+
t.code = kDLBfloat;
54+
scan = s + 6;
5255
} else if (!strncmp(s, "handle", 6)) {
5356
t.code = kTVMOpaqueHandle;
5457
t.bits = 64; // handle uses 64 bit by default.

0 commit comments

Comments
 (0)