Commit feb1043

[TIR][CUDA] Add native FP8 support to codegen (#16548)
* Add native FP8 type support for the CUDA codegen. The e4m3/e5m2 struct types provide explicit type conversions that target hardware-native conversion ops.
* Conditionally run Storage and Compute legalization for targets that don't support FP8. This could be changed to only support conversion operators and do legalization on any compute operations other than builtin wmma calls.
* Implement support for float16x4 (half4) for use with e4m3_float8x4 (__nv_fp8x4_e4m3).
* Add a test for e4m3 <-> half conversion which lowers to ptx intrins.
* Introduce half4 and support native fp8 vector types (1, 2, 4), and conversion between float and half vector types with equal lanes.
* Only cast to half2 for vector loads/stores of non-native half struct types (lanes > 4).
* Test e4m3 x4 vector quant/dequant.

Co-authored-by: Joseph McMahan <[email protected]>
1 parent 45df124 commit feb1043
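
To ground the commit message, here is a minimal CUDA sketch (not part of this commit; the kernel and variable names are invented) of the hardware-native e4m3 <-> fp16 conversion path the new codegen targets on sm_89 and newer:

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Round-trip fp16 -> e4m3 -> fp16 with the CUDA fp8 struct types. On
// __CUDA_ARCH__ >= 890 these explicit conversions are intended to lower to
// hardware conversion instructions rather than emulation.
__global__ void fp8_roundtrip(const __half* in, __half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __nv_fp8_e4m3 q(in[i]);           // quantize: fp16 -> e4m3
    out[i] = static_cast<__half>(q);  // dequantize: e4m3 -> fp16
  }
}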

File tree

8 files changed: 957 additions, 45 deletions

include/tvm/tir/transform.h

Lines changed: 4 additions & 2 deletions
@@ -398,17 +398,19 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
 /*!
  * \brief Legalize bf16 compute Ops. Add a cast to fp32
  * before Ops, then add a cast back to bf16.
+ * \param target The target used for checking native bf16 support
  * \return The pass.
  */
 TVM_DLL Pass BF16ComputeLegalize();

 /*!
  * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
  * before Ops, then add a cast back to fp8.
+ * \param target The target used for checking native fp8 support
  * \param promote_dtype_str The data type used for type promotion, defaults to float16
  * \return The pass.
  */
-TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
+TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16");

 /*!
  * \brief Legalize bf16 storage types to u16.
@@ -420,7 +422,7 @@ TVM_DLL Pass BF16StorageLegalize();
  * \brief Legalize fp8 storage types to u8.
  * \return The pass.
  */
-TVM_DLL Pass FP8StorageLegalize();
+TVM_DLL Pass FP8StorageLegalize(Target target);

 /*!
  * \brief Inline calls to private functions
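
The doc comments above describe the legalization fallback: when a target lacks native FP8, each fp8 op gets a cast to the promotion dtype (float16 by default) in front of it and a cast back to fp8 after it. A hedged CUDA sketch of that cast-compute-cast pattern follows; it is not code produced by the pass, and legalized_add is an invented name:

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Illustrative only: an fp8 add written the way the legalizer rewrites it,
// i.e. widen both operands to float16, run the op there, then narrow back.
__device__ __nv_fp8_e4m3 legalized_add(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) {
  __half wa = static_cast<__half>(a);  // cast inserted before the op
  __half wb = static_cast<__half>(b);
  __half sum = __hadd(wa, wb);         // the op runs in the promoted dtype
  return __nv_fp8_e4m3(sum);           // cast back to fp8 afterwards
}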

python/tvm/contrib/nvcc.py

Lines changed: 3 additions & 0 deletions
@@ -270,6 +270,7 @@ def callback_libdevice_path(arch):
         return ""


+@tvm._ffi.register_func("tvm.contrib.nvcc.get_compute_version")
 def get_target_compute_version(target=None):
     """Utility function to get compute capability of compilation target.

@@ -406,6 +407,7 @@ def have_cudagraph():
         return False


+@tvm._ffi.register_func("tvm.contrib.nvcc.supports_bf16")
 def have_bf16(compute_version):
     """Either bf16 support is provided in the compute capability or not

@@ -421,6 +423,7 @@ def have_bf16(compute_version):
     return False


+@tvm._ffi.register_func("tvm.contrib.nvcc.supports_fp8")
 def have_fp8(compute_version):
     """Whether fp8 support is provided in the specified compute capability or not


src/driver/driver_api.cc

Lines changed: 3 additions & 2 deletions
@@ -216,7 +216,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::TransformMmaBufferLayout());
   pass_list.push_back(tir::transform::LowerOpaqueBlock());
   pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::FP8ComputeLegalize());
   pass_list.push_back(tir::transform::BF16ComputeLegalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
@@ -570,6 +569,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

   Array<Pass> mixed_pass_list;

+  mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));
+
   // VerifyVTCMLimit must occur before LowerVtcmAlloc
   mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
   // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
@@ -619,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
   } else {
     mixed_pass_list.push_back(tir::transform::MakePackedAPI());
   }
-  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
+  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
   mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

   mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

src/target/llvm/codegen_llvm.cc

Lines changed: 2 additions & 0 deletions
@@ -586,6 +586,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
       default:
         LOG(FATAL) << "do not support " << dtype;
     }
+  } else if (dtype.code() == DataType::kE4M3Float || dtype.code() == DataType::kE5M2Float) {
+    etype = llvm::Type::getInt8Ty(*ctx);
   }
   if (!dtype.is_scalar()) {
 #if TVM_LLVM_VERSION >= 110

src/target/source/codegen_cuda.cc

Lines changed: 77 additions & 36 deletions
@@ -41,6 +41,31 @@
 namespace tvm {
 namespace codegen {

+std::string GetFP8Type(DataType type) {
+  std::stringstream stream;
+  int32_t lanes = type.lanes();
+  std::string vec;
+  if (type.is_scalar()) {
+    vec = "";
+  } else if (lanes == 2) {
+    vec = "_2";
+  } else if (lanes == 4) {
+    vec = "_4";
+  } else if (lanes == 8) {
+    vec = "_8";
+  } else {
+    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
+  }
+  if (type.code() == DataType::kE4M3Float) {
+    stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kE5M2Float) {
+    stream << "fp8_e5" << vec << "_t";
+  } else {
+    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
+  }
+  return stream.str();
+}
+
 CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }

 void CodeGenCUDA::Init(bool output_ssa) {
@@ -121,8 +146,15 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
     decl_stream << "#include <cuda_fp8.h>\n";
+    decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
+    decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
+    decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
+    decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
+    decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
+    decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
     decl_stream << "#endif\n\n";
   }
+  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);

   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
@@ -214,17 +246,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
       if (t.is_scalar()) {
         os << "half";
       } else if (lanes <= 8) {
-        // Emit CUDA code to access fp16 vector elements.
-        //
-        // half4 is stored as uint2
-        //
-        // h4.x is emitted as *(half2*)(&(u2.x)).x
-        // h4.y is emitted as *(half2*)(&(u2.x)).y
-        // h4.z is emitted as *(half2*)(&(u2.y)).x
-        // h4.w is emitted as *(half2*)(&(u2.y)).y
-        //
-        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
-        os << "uint" << lanes / 2;
+        ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type";
+        if (lanes <= 4) {
+          os << "half" << lanes;
+        } else {
+          os << "uint" << lanes / 2;
+        }
       } else {
         fail = true;
       }
@@ -271,16 +298,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
     }
     if (!fail) return;
   } else if (t.is_float8()) {
-    if (t.is_scalar()) {
-      os << "unsigned char";  // __nv_fp8_storage_t is an alias of unsigned char
-    } else if (lanes == 2) {
-      os << "unsigned short int";  // __nv_fp8x2_storage_t is an alias of unsigned short
-    } else if (lanes == 4) {
-      os << "unsigned int";  // __nv_fp8x4_storage_t is an alias of unsigned int
-    } else {
-      fail = true;
-    }
-    if (!fail) return;
+    enable_fp8_ = true;
+    os << GetFP8Type(t);
+    return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;
@@ -446,7 +466,7 @@ void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {

 void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                    std::ostream& os) {  // NOLINT(*)
-  // Delcare the result.
+  // Declare the result.
   std::string sret = name_supply_->FreshName("_");
   this->PrintIndent();
   this->PrintType(t, stream);
@@ -497,7 +517,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
     }
   } else if (t.is_float16()) {
-    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
+    if (t.lanes() <= 4) {
+      os << vec << "." << access[i];
+    } else {
+      os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
+    }
   } else if (t.is_bfloat16()) {
     os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
@@ -543,8 +567,13 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
     }
   } else if (t.is_float16()) {
-    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
-           << value << ";\n";
+    if (t.lanes() <= 4) {
+      stream << vec << "." << access[i] << " = " << value << ";\n";
+    } else {
+      stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
+             << value << ";\n";
+    }
+
   } else if (t.is_bfloat16()) {
     stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
            << " = " << value << ";\n";
@@ -648,6 +677,16 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
   // Emit simple C-style type conversion.
   if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);

+  if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float ||
+      from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) {
+    std::ostringstream val;
+    val << "(";
+    PrintType(target_ty, val);
+    val << ")(" << PrintExpr(op->value) << ")";
+    os << val.str();
+    return;
+  }
+
   // We could emit make_float4 like calls, but the emitted code looks
   // too compact to read. Emit this as vectorized unary ops.
   std::string sret = name_supply_->FreshName("_");
@@ -1194,9 +1233,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
     std::string v = PrintExpr(op->value);
     PrintVecConstructor(op->dtype, os);
     os << '(';
-    for (int i = 0; i < lanes / 2; ++i) {
-      if (i != 0) os << ", ";
-      os << "__pack_half2(" << v << ", " << v << ")";
+    if (lanes <= 4) {
+      for (int i = 0; i < lanes / 2; ++i) {
+        if (i != 0) os << ", ";
+        os << v << ", " << v;
+      }
+    } else {
+      for (int i = 0; i < lanes / 2; ++i) {
+        if (i != 0) os << ", ";
+        os << "__pack_half2(" << v << ", " << v << ")";
+      }
     }
     os << ')';
     return;
@@ -1448,15 +1494,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
       PrintVecConstructor(t, os);
       os << '(';
     }
-    if (i % 2 == 0) {
-      os << "__pack_half2(" << value;
+    if (i == t.lanes() - 1) {
+      os << value << ")";
     } else {
-      os << "," << value << ")";
-      if (i != t.lanes() - 1) {
-        os << ",";
-      } else {
-        os << ")";
-      }
+      os << value << ",";
     }
     return;
   }
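
Taken together, the codegen changes above print an fp8 cast as a plain C-style cast on the type aliases declared in Finish(). A hedged sketch of the emitted pattern (the kernel and variable names are invented; only the fp8_e4_t alias comes from this diff):

#include <cuda_fp16.h>
#include <cuda_fp8.h>

using fp8_e4_t = __nv_fp8_e4m3;  // alias the codegen declares when enable_fp8_ is set

// Scalar quantize/dequantize written the way the cast handling prints it: a
// C-style cast to the target type, which resolves to the explicit
// constructors/conversion operators of the __nv_fp8 structs.
__global__ void cast_like_codegen(const __half* x, fp8_e4_t* y, __half* z, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = (fp8_e4_t)(x[i]);  // Cast(e4m3, half) prints as "(fp8_e4_t)(x)"
    z[i] = (__half)(y[i]);    // Cast(half, e4m3) prints as "(half)(y)"
  }
}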

src/target/source/literal/cuda_half_t.h

Lines changed: 42 additions & 0 deletions
@@ -24,6 +24,8 @@
 #ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
 #define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_

+#include <string>
+
 static constexpr const char* _cuda_half_t_def = R"(
 typedef unsigned short uint16_t;
 typedef unsigned char uint8_t;
@@ -379,4 +381,44 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(

 )";

+void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8) {
+  if (enable_fp16 || enable_fp8) {
+    stream << R"(
+struct __align__(8) half4 {
+  __half x, y, z, w;
+  __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
+  __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
+)";
+    if (enable_fp8) {
+      stream << R"(
+  __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
+    __nv_fp8x2_e4m3 lo_part, hi_part;
+    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
+    __half2 lo_half2 = static_cast<__half2>(lo_part);
+    __half2 hi_half2 = static_cast<__half2>(hi_part);
+    x = reinterpret_cast<__half*>(&lo_half2)[0];
+    y = reinterpret_cast<__half*>(&lo_half2)[1];
+    z = reinterpret_cast<__half*>(&hi_half2)[0];
+    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  }
+  __host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
+    __nv_fp8x4_e4m3 result;
+    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
+    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
+    result.__x =
+        (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
+    return result;
+  })";
+    }
+    stream << R"(
+};
+__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
+  return half4(x, y, z, w);
+}
+)";
+  }
+}
+
 #endif  // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
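
A hedged usage sketch of the extension above, matching the e4m3 x4 quant/dequant test called out in the commit message. It assumes the half4 struct emitted by declare_vector_type_extensions is in scope (for example, pasted from the generated preamble) along with cuda_fp8.h; the kernel name is invented:

// Round-trip a 4-lane fp16 vector through e4m3 using the conversions the
// generated half4 struct provides: the explicit operator __nv_fp8x4_e4m3()
// for quantization and the explicit half4(const __nv_fp8x4_e4m3&) constructor
// for dequantization.
__global__ void quant_dequant_x4(const half4* in, half4* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __nv_fp8x4_e4m3 q = __nv_fp8x4_e4m3(in[i]);  // half4 -> e4m3x4
    out[i] = half4(q);                           // e4m3x4 -> half4
  }
}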

src/tir/transforms/unsupported_dtype_legalize.cc

Lines changed: 23 additions & 5 deletions
@@ -693,6 +693,20 @@ class FP8StorageLegalizer : public StorageLegalizer {

 namespace transform {

+bool CheckDataTypeSupport(const Target& target, const std::string& support_func_name) {
+  bool has_native_support = false;
+  if (target->kind->name == "cuda") {
+    if (const PackedFunc* get_cv =
+            tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) {
+      std::string compute_version = (*get_cv)(target);
+      if (const PackedFunc* check_support = tvm::runtime::Registry::Get(support_func_name)) {
+        has_native_support = (*check_support)(compute_version);
+      }
+    }
+  }
+  return has_native_support;
+}
+
 Pass BF16ComputeLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     // TODO(tvm-team): skip if the target supports bf16
@@ -713,19 +727,23 @@ Pass BF16StorageLegalize() {

 TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);

-Pass FP8ComputeLegalize(String promote_dtype_str) {
+Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    // TODO(tvm-team): skip if the target supports fp8
+    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+      return f;
+    }
     return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
 }

 TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);

-Pass FP8StorageLegalize() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    // TODO(tvm-team): skip if the target supports fp8
+Pass FP8StorageLegalize(Target target) {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
+      return f;
+    }
     return FP8StorageLegalizer().Legalize(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {});
