Skip to content

Commit ae7b8d9

Browse files
authored
[Codegen, Cuda] Add overload for fp8x4 e5m2 <-> half4 conversion (#16787)
1 parent 69c0914 commit ae7b8d9

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/target/source/literal/cuda_half_t.h

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -410,7 +410,28 @@ struct __align__(8) half4 {
410410
result.__x =
411411
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
412412
return result;
413-
})";
413+
}
414+
// Construct a half4 from a packed __nv_fp8x4_e5m2 value.
// The packed storage word fp8x4.__x is split into two 16-bit pieces, each
// piece is placed in an __nv_fp8x2_e5m2, converted to __half2, and the four
// resulting __half elements are stored in x, y, z, w (low pair first).
__host__ __device__ explicit half4(const __nv_fp8x4_e5m2& fp8x4) {
415+
__nv_fp8x2_e5m2 lo_part, hi_part;
416+
// Bits 0..15 of the packed word become the storage of the low fp8x2 pair.
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
417+
// Bits 16..31 of the packed word become the storage of the high fp8x2 pair.
hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
418+
// Widen each fp8x2 pair to a __half2 via the type's conversion to __half2.
__half2 lo_half2 = static_cast<__half2>(lo_part);
419+
__half2 hi_half2 = static_cast<__half2>(hi_part);
420+
// View each __half2 as an array of two __half and copy the elements out:
// low pair -> (x, y), high pair -> (z, w).
x = reinterpret_cast<__half*>(&lo_half2)[0];
421+
y = reinterpret_cast<__half*>(&lo_half2)[1];
422+
z = reinterpret_cast<__half*>(&hi_half2)[0];
423+
w = reinterpret_cast<__half*>(&hi_half2)[1];
424+
}
425+
// Convert this half4 to a packed __nv_fp8x4_e5m2 value.
// The members starting at x and at z are each read as a __half2, each
// __half2 is converted to an __nv_fp8x2_e5m2, and the two 16-bit storage
// values are combined into result.__x (low pair in bits 0..15, high pair
// shifted into bits 16..31).
__host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
426+
__nv_fp8x4_e5m2 result;
427+
// NOTE(review): reading a __half2 through &x / &z assumes (x, y) and (z, w)
// are adjacent __half members and suitably aligned; the member declarations
// are not visible in this hunk -- confirm against the struct definition.
__half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
428+
__half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
429+
// Narrow each __half2 to an fp8x2 e5m2 pair via the fp8x2 constructor.
__nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
430+
// Pack: low pair occupies the low 16 bits, high pair is shifted left by 16.
result.__x =
431+
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
432+
return result;
433+
}
434+
)";
414435
}
415436
stream << R"(
416437
};

0 commit comments

Comments
 (0)