Skip to content

Commit 5d9440a

Browse files
committed
[TIR] Refactor BF16Legalize
This PR refactors BF16Legalize to enable more f32 computations. We also split the BF16Legalize into two steps. - BF16ComputeLegalize changes all computation to f32 while keeping the external BF16 storages. - BF16StorageLegalize changes all storage to u16. Now BF16 kernels accept tvm.nd.array that are created as bfloat16 type.
1 parent 0d0d2f0 commit 5d9440a

File tree

15 files changed

+613
-423
lines changed

15 files changed

+613
-423
lines changed

include/tvm/tir/transform.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,11 +337,17 @@ TVM_DLL Pass CombineContextCall();
337337
TVM_DLL Pass NarrowDataType(int target_bits);
338338

339339
/*!
340-
* \brief Legalize bf16 typed Ops. Add a cast to fp32
340+
* \brief Legalize bf16 compute Ops. Add a cast to fp32
341341
* before Ops, then add a cast back to bf16.
342342
* \return The pass.
343343
*/
344-
TVM_DLL Pass BF16Legalize();
344+
TVM_DLL Pass BF16ComputeLegalize();
345+
346+
/*!
347+
* \brief Legalize bf16 storage types to u16.
348+
* \return The pass.
349+
*/
350+
TVM_DLL Pass BF16StorageLegalize();
345351

346352
/*!
347353
* \brief Rewrite the pointer content type of arguments,

include/tvm/topi/elemwise.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
310310
inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor",
311311
std::string tag = kElementWise) {
312312
return compute(
313-
x->shape,
314-
[&](const Array<Var>& i) {
315-
return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)});
316-
},
317-
name, tag);
313+
x->shape, [&](const Array<Var>& i) { return reinterpret(type, x(i)); }, name, tag);
318314
}
319315

320316
/*!

python/tvm/tir/transform/transform.py

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -286,59 +286,26 @@ def RemoveStoreUndef():
286286
return _ffi_api.RemoveStoreUndef() # type: ignore
287287

288288

289-
def BF16Legalize():
290-
"""Legalize bf16 typed Ops.
291-
Runs BF16Promote, BF16CastElimination and BF16TypeLowering
289+
def BF16ComputeLegalize():
290+
"""Legalize bf16 compute Ops.
292291
293292
Returns
294293
-------
295294
fpass : tvm.transform.Pass
296295
The result pass
297296
"""
298-
return _ffi_api.BF16Legalize() # type: ignore
297+
return _ffi_api.BF16ComputeLegalize() # type: ignore
299298

300299

301-
def BF16Promote():
302-
"""Promote bf16 to fp32. Add a cast to fp32
303-
before Ops, then add a cast back to bf16.
300+
def BF16StorageLegalize():
301+
"""Legalize bf16 storage types to u16.
304302
305303
Returns
306304
-------
307305
fpass : tvm.transform.Pass
308306
The result pass
309307
"""
310-
return _ffi_api.BF16Promote() # type: ignore
311-
312-
313-
def BF16CastElimination():
314-
"""Eliminate verbose casting between fp32 and bf16
315-
Checks if the AST has the pattern:
316-
castto32(castto16(some_fp32_op(...)))
317-
The verbose casting is generated by BF16Promote for multiple
318-
bf16 Ops in a row. e.g.:
319-
X[i] + Y[i] + T[i] =>
320-
bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
321-
After this pass:
322-
bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
323-
324-
Returns
325-
-------
326-
fpass : tvm.transform.Pass
327-
The result pass
328-
"""
329-
return _ffi_api.BF16CastElimination() # type: ignore
330-
331-
332-
def BF16TypeLowering():
333-
"""Replace all bf16 type with uint16. Also lower the casting
334-
between fp32 and bf16
335-
336-
Returns
337-
-------
338-
fpass : tvm.transform.Pass
339-
The result pass
340-
"""
341-
return _ffi_api.BF16TypeLowering() # type: ignore
308+
return _ffi_api.BF16StorageLegalize() # type: ignore
342309

343310

344311
def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):

src/driver/driver_api.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
218218
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
219219
pass_list.push_back(tir::transform::LowerOpaqueBlock());
220220
pass_list.push_back(tir::transform::FlattenBuffer());
221-
pass_list.push_back(tir::transform::BF16Legalize());
221+
pass_list.push_back(tir::transform::BF16ComputeLegalize());
222222
pass_list.push_back(tir::transform::NarrowDataType(32));
223223
pass_list.push_back(tir::transform::Simplify());
224224

@@ -605,6 +605,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
605605
} else {
606606
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
607607
}
608+
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
608609
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
609610

610611
return transform::Sequential(mixed_pass_list);

src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
138138
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
139139
pass_list.push_back(tir::transform::LowerOpaqueBlock());
140140
pass_list.push_back(tir::transform::FlattenBuffer());
141-
pass_list.push_back(tir::transform::BF16Legalize());
141+
pass_list.push_back(tir::transform::BF16ComputeLegalize());
142142
pass_list.push_back(tir::transform::NarrowDataType(32));
143143
pass_list.push_back(tir::transform::Simplify());
144144
pass_list.push_back(tir::transform::InjectVirtualThread());

src/meta_schedule/postproc/verify_gpu_code.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ class VerifyGPUCodeNode : public PostprocNode {
169169
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
170170
pass_list.push_back(tir::transform::LowerOpaqueBlock());
171171
pass_list.push_back(tir::transform::FlattenBuffer());
172-
pass_list.push_back(tir::transform::BF16Legalize());
172+
pass_list.push_back(tir::transform::BF16ComputeLegalize());
173173
pass_list.push_back(tir::transform::NarrowDataType(32));
174174
pass_list.push_back(tir::transform::Simplify());
175175
// Phase 2

src/target/codegen.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ runtime::Module Build(IRModule mod, Target target) {
4646
.value()) {
4747
mod = tir::transform::SkipAssert()(mod);
4848
}
49-
5049
auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
5150
if (target_attr_map.count(target->kind)) {
5251
return target_attr_map[target->kind](mod, target);

src/target/llvm/codegen_llvm.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va
828828
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
829829
llvm::Type* target = DTypeToLLVMType(to);
830830
if (value->getType() == target) return value;
831+
// TODO(tvm-team): consider add native support
832+
ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first";
833+
ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first";
834+
831835
if (to.is_handle()) {
832836
return builder_->CreateBitCast(value, target);
833837
} else if (to.is_uint() && to.bits() == 1) {

src/target/llvm/llvm_module.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) {
325325
if (tm->getTargetTriple().isOSDarwin()) {
326326
module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
327327
}
328-
329328
std::string verify_errors_storage;
330329
llvm::raw_string_ostream verify_errors(verify_errors_storage);
331330
LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors))

src/tir/op/op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,8 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
324324
// reinterpret
325325
PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
326326
if (value.dtype() == t) return value;
327+
ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
328+
<< "Bitcast requires size match " << t << " vs " << value.dtype();
327329
return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
328330
}
329331

0 commit comments

Comments
 (0)