File tree: 1 file changed, +22 -1 lines changed
src/target/source/literal: 1 file changed, +22 -1 lines changed
Columns: original file line number | diff line number | diff line change
@@ -410,7 +410,28 @@ struct __align__(8) half4 {
410410 result.__x =
411411 (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
412412 return result;
413- })" ;
413+ }
414+ __host__ __device__ explicit half4(const __nv_fp8x4_e5m2& fp8x4) {
415+ __nv_fp8x2_e5m2 lo_part, hi_part;
416+ lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
417+ hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
418+ __half2 lo_half2 = static_cast<__half2>(lo_part);
419+ __half2 hi_half2 = static_cast<__half2>(hi_part);
420+ x = reinterpret_cast<__half*>(&lo_half2)[0];
421+ y = reinterpret_cast<__half*>(&lo_half2)[1];
422+ z = reinterpret_cast<__half*>(&hi_half2)[0];
423+ w = reinterpret_cast<__half*>(&hi_half2)[1];
424+ }
425+ __host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
426+ __nv_fp8x4_e5m2 result;
427+ __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
428+ __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
429+ __nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
430+ result.__x =
431+ (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
432+ return result;
433+ }
434+ )" ;
414435 }
415436 stream << R"(
416437};
You can’t perform that action at this time.
0 commit comments