|
41 | 41 | namespace tvm { |
42 | 42 | namespace codegen { |
43 | 43 |
|
| 44 | +std::string GetFP8Type(DataType type) { |
| 45 | + std::stringstream stream; |
| 46 | + int32_t lanes = type.lanes(); |
| 47 | + std::string vec; |
| 48 | + if (type.is_scalar()) { |
| 49 | + vec = ""; |
| 50 | + } else if (lanes == 2) { |
| 51 | + vec = "_2"; |
| 52 | + } else if (lanes == 4) { |
| 53 | + vec = "_4"; |
| 54 | + } else if (lanes == 8) { |
| 55 | + vec = "_8"; |
| 56 | + } else { |
| 57 | + LOG(FATAL) << "FP8 only supports scalar types and vectors with 2, 4, or 8 lanes";
| 58 | + } |
| 59 | + if (type.code() == DataType::kE4M3Float) { |
| 60 | + stream << "fp8_e4" << vec << "_t"; |
| 61 | + } else if (type.code() == DataType::kE5M2Float) { |
| 62 | + stream << "fp8_e5" << vec << "_t"; |
| 63 | + } else { |
| 64 | + LOG(FATAL) << "Unsupported FP8 type in CUDA codegen"; |
| 65 | + } |
| 66 | + return stream.str(); |
| 67 | +} |
| 68 | + |
44 | 69 | CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } |
45 | 70 |
|
46 | 71 | void CodeGenCUDA::Init(bool output_ssa) { |
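
For reference, here is a minimal host-side sketch of the name mapping GetFP8Type implements. The free function FP8Name and its bool/int signature are illustrative stand-ins, not TVM's API; it only mirrors the prefix/suffix logic above.

    #include <cassert>
    #include <sstream>
    #include <string>

    // Mirrors GetFP8Type: "fp8_e4"/"fp8_e5" prefix, "_2"/"_4"/"_8" lane suffix.
    std::string FP8Name(bool e4m3, int lanes) {
      std::ostringstream s;
      s << (e4m3 ? "fp8_e4" : "fp8_e5");
      if (lanes != 1) s << "_" << lanes;
      s << "_t";
      return s.str();
    }

    int main() {
      assert(FP8Name(true, 1) == "fp8_e4_t");     // scalar e4m3
      assert(FP8Name(true, 4) == "fp8_e4_4_t");   // 4-lane e4m3
      assert(FP8Name(false, 2) == "fp8_e5_2_t");  // 2-lane e5m2
      return 0;
    }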
@@ -121,8 +146,15 @@ std::string CodeGenCUDA::Finish() { |
121 | 146 | if (enable_fp8_) { |
122 | 147 | decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n"; |
123 | 148 | decl_stream << "#include <cuda_fp8.h>\n"; |
| 149 | + decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n"; |
| 150 | + decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n"; |
| 151 | + decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n"; |
| 152 | + decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n"; |
| 153 | + decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n"; |
| 154 | + decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n"; |
124 | 155 | decl_stream << "#endif\n\n"; |
125 | 156 | } |
| 157 | + declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_); |
126 | 158 |
|
127 | 159 | if (enable_warp_shuffle_) { |
128 | 160 | decl_stream << _cuda_warp_intrinsic_util; |
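
With the aliases above in place, generated kernels can spell fp8 types by name. A toy kernel in the same style, assuming an Ada-class GPU (nvcc -arch=sm_89 or newer); the kernel name roundtrip is made up for illustration, and the float constructor and float conversion used are the ones cuda_fp8.h provides.

    #include <cuda_fp8.h>
    using fp8_e4_t = __nv_fp8_e4m3;

    // Round-trip a float through 8-bit e4m3 storage.
    __global__ void roundtrip(const float* in, float* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        fp8_e4_t q = fp8_e4_t(in[i]);  // narrow: float -> e4m3
        out[i] = float(q);             // widen: e4m3 -> float
      }
    }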
@@ -214,17 +246,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) |
214 | 246 | if (t.is_scalar()) { |
215 | 247 | os << "half"; |
216 | 248 | } else if (lanes <= 8) { |
217 | | - // Emit CUDA code to access fp16 vector elements. |
218 | | - // |
219 | | - // half4 is stored as uint2 |
220 | | - // |
221 | | - // h4.x is emitted as *(half2*)(&(u2.x)).x |
222 | | - // h4.y is emitted as *(half2*)(&(u2.x)).y |
223 | | - // h4.z is emitted as *(half2*)(&(u2.y)).x |
224 | | - // h4.w is emitted as *(half2*)(&(u2.y)).y |
225 | | - // |
226 | | - ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; |
227 | | - os << "uint" << lanes / 2; |
| 249 | + ICHECK_EQ(lanes % 2, 0) << "Only support an even number of lanes for half type"; |
| 250 | + if (lanes <= 4) { |
| 251 | + os << "half" << lanes; |
| 252 | + } else { |
| 253 | + os << "uint" << lanes / 2; |
| 254 | + } |
228 | 255 | } else { |
229 | 256 | fail = true; |
230 | 257 | } |
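
For lanes <= 4 the half type now lowers to half2 or half4 rather than packed uints; half4 is not a CUDA built-in, so it has to come from declare_vector_type_extensions. Below is the minimal shape assumed for it in the sketches that follow (the real struct also carries fp8 conversions); 8-lane values still lower to uint4 as before.

    #include <cuda_fp16.h>

    // Assumed minimal half4, standing in for the extension header's struct.
    struct __align__(8) half4 {
      __half x, y, z, w;
    };

    __global__ void fill4(half4* out, float f) {
      __half v = __float2half(f);
      half4 h;
      h.x = v; h.y = v; h.z = v; h.w = v;
      out[blockIdx.x] = h;
    }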
@@ -271,16 +298,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) |
271 | 298 | } |
272 | 299 | if (!fail) return; |
273 | 300 | } else if (t.is_float8()) { |
274 | | - if (t.is_scalar()) { |
275 | | - os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char |
276 | | - } else if (lanes == 2) { |
277 | | - os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short |
278 | | - } else if (lanes == 4) { |
279 | | - os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int |
280 | | - } else { |
281 | | - fail = true; |
282 | | - } |
283 | | - if (!fail) return; |
| 301 | + enable_fp8_ = true; |
| 302 | + os << GetFP8Type(t); |
| 303 | + return; |
284 | 304 | } else if (t == DataType::Bool()) { |
285 | 305 | os << "bool"; |
286 | 306 | return; |
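
The named fp8 types occupy exactly the storage of the unsigned-integer aliases they replace, so data layouts are unchanged; a compile-time check (host-only, build with nvcc):

    #include <cuda_fp8.h>

    static_assert(sizeof(__nv_fp8_e4m3) == sizeof(unsigned char),
                  "scalar fp8 is one byte");
    static_assert(sizeof(__nv_fp8x2_e5m2) == sizeof(unsigned short),
                  "2-lane fp8 is two bytes");
    static_assert(sizeof(__nv_fp8x4_e4m3) == sizeof(unsigned int),
                  "4-lane fp8 is four bytes");

    int main() { return 0; }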
@@ -446,7 +466,7 @@ void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) { |
446 | 466 |
|
447 | 467 | void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, |
448 | 468 | std::ostream& os) { // NOLINT(*) |
449 | | - // Delcare the result. |
| 469 | + // Declare the result. |
450 | 470 | std::string sret = name_supply_->FreshName("_"); |
451 | 471 | this->PrintIndent(); |
452 | 472 | this->PrintType(t, stream); |
@@ -497,7 +517,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, |
497 | 517 | os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; |
498 | 518 | } |
499 | 519 | } else if (t.is_float16()) { |
500 | | - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
| 520 | + if (t.lanes() <= 4) { |
| 521 | + os << vec << "." << access[i]; |
| 522 | + } else { |
| 523 | + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
| 524 | + } |
501 | 525 | } else if (t.is_bfloat16()) { |
502 | 526 | os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; |
503 | 527 | } else if (t.lanes() > 4 && t.lanes() <= 8) { |
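
Element loads now split on lane count: direct member access for lanes <= 4, and the original half2 reinterpret trick for 8 lanes. A sketch under the same half4 assumption as above; the kernel name elem_load is illustrative:

    #include <cuda_fp16.h>

    struct __align__(8) half4 { __half x, y, z, w; };  // assumed extension type

    __global__ void elem_load(half4 v4, uint4 v8, __half* out) {
      out[0] = v4.z;                    // element 2 of a half4: direct access
      out[1] = ((half2*)(&(v8.y)))->x;  // element 2 of a half8 stored in uint4
    }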
@@ -543,8 +567,13 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, |
543 | 567 | stream << "(" << value << " << " << i % 4 * 8 << ");\n"; |
544 | 568 | } |
545 | 569 | } else if (t.is_float16()) { |
546 | | - stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " |
547 | | - << value << ";\n"; |
| 570 | + if (t.lanes() <= 4) { |
| 571 | + stream << vec << "." << access[i] << " = " << value << ";\n"; |
| 572 | + } else { |
| 573 | + stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " |
| 574 | + << value << ";\n"; |
| 575 | + } |
548 | 577 | } else if (t.is_bfloat16()) { |
549 | 578 | stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] |
550 | 579 | << " = " << value << ";\n"; |
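
The store path mirrors the load path. Under the same half4 assumption:

    #include <cuda_fp16.h>

    struct __align__(8) half4 { __half x, y, z, w; };  // assumed extension type

    __global__ void elem_store(half4* out4, uint4* out8, __half v) {
      out4->w = v;                    // element 3 of a half4: direct store
      ((half2*)(&(out8->z)))->y = v;  // element 5 of a half8 via half2 view
    }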
@@ -648,6 +677,16 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { |
648 | 677 | // Emit simple C-style type conversion. |
649 | 678 | if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); |
650 | 679 |
|
| 680 | + if (target_ty.code() == DataType::kE4M3Float || target_ty.code() == DataType::kE5M2Float || |
| 681 | + from_ty.code() == DataType::kE4M3Float || from_ty.code() == DataType::kE5M2Float) { |
| 682 | + std::ostringstream val; |
| 683 | + val << "("; |
| 684 | + PrintType(target_ty, val); |
| 685 | + val << ")(" << PrintExpr(op->value) << ")"; |
| 686 | + os << val.str(); |
| 687 | + return; |
| 688 | + } |
| 689 | + |
651 | 690 | // We could emit make_float4 like calls, but the emitted code looks |
652 | 691 | // too compact to read. Emit this as vectorized unary ops. |
653 | 692 | std::string sret = name_supply_->FreshName("_"); |
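
For fp8, the emitted C-style cast resolves to the constructors and conversion operators that cuda_fp8.h defines on the vector types (this assumes the float4 constructor and float4 conversion on __nv_fp8x4_e4m3, present since CUDA 11.8). A sketch of what a 4-lane cast boils down to; the kernel name cast4 is illustrative:

    #include <cuda_fp8.h>
    using fp8_e4_4_t = __nv_fp8x4_e4m3;

    __global__ void cast4(const float4* in, float4* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        fp8_e4_4_t q = (fp8_e4_4_t)(in[i]);  // float4 -> 4 x e4m3
        out[i] = (float4)(q);                // 4 x e4m3 -> float4
      }
    }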
@@ -1194,9 +1233,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO |
1194 | 1233 | std::string v = PrintExpr(op->value); |
1195 | 1234 | PrintVecConstructor(op->dtype, os); |
1196 | 1235 | os << '('; |
1197 | | - for (int i = 0; i < lanes / 2; ++i) { |
1198 | | - if (i != 0) os << ", "; |
1199 | | - os << "__pack_half2(" << v << ", " << v << ")"; |
| 1236 | + if (lanes <= 4) { |
| 1237 | + for (int i = 0; i < lanes / 2; ++i) { |
| 1238 | + if (i != 0) os << ", "; |
| 1239 | + os << v << ", " << v; |
| 1240 | + } |
| 1241 | + } else { |
| 1242 | + for (int i = 0; i < lanes / 2; ++i) { |
| 1243 | + if (i != 0) os << ", "; |
| 1244 | + os << "__pack_half2(" << v << ", " << v << ")"; |
| 1245 | + } |
1200 | 1246 | } |
1201 | 1247 | os << ')'; |
1202 | 1248 | return; |
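
A broadcast to lanes <= 4 now lists the scalar once per lane inside the vector constructor instead of packing half2 pairs. Assuming a make_half4 helper of the kind declare_vector_type_extensions provides (sketched below), the emitted call is make_half4(v, v, v, v):

    #include <cuda_fp16.h>

    struct __align__(8) half4 { __half x, y, z, w; };  // assumed extension type
    __device__ half4 make_half4(__half a, __half b, __half c, __half d) {
      half4 h;
      h.x = a; h.y = b; h.z = c; h.w = d;
      return h;
    }

    __global__ void broadcast4(half4* out, float f) {
      __half v = __float2half(f);
      out[blockIdx.x] = make_half4(v, v, v, v);  // no __pack_half2 needed
    }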
@@ -1448,15 +1494,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val |
1448 | 1494 | PrintVecConstructor(t, os); |
1449 | 1495 | os << '('; |
1450 | 1496 | } |
1451 | | - if (i % 2 == 0) { |
1452 | | - os << "__pack_half2(" << value; |
| 1497 | + if (i == t.lanes() - 1) { |
| 1498 | + os << value << ")"; |
1453 | 1499 | } else { |
1454 | | - os << "," << value << ")"; |
1455 | | - if (i != t.lanes() - 1) { |
1456 | | - os << ","; |
1457 | | - } else { |
1458 | | - os << ")"; |
1459 | | - } |
| 1500 | + os << value << ","; |
1460 | 1501 | } |
1461 | 1502 | return; |
1462 | 1503 | } |
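
The simplified branch appends each element followed by a comma and lets the final element close the constructor call. A host-side sketch of the string this produces for a 4-lane half load; the v0..v3 names are placeholders:

    #include <iostream>
    #include <sstream>

    int main() {
      const int lanes = 4;
      std::ostringstream os;
      os << "make_half4(";  // PrintVecConstructor has already been emitted
      for (int i = 0; i < lanes; ++i) {
        os << "v" << i << (i == lanes - 1 ? ")" : ",");
      }
      std::cout << os.str() << "\n";  // prints: make_half4(v0,v1,v2,v3)
      return 0;
    }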
|