diff --git a/Publications/GPU-Opt-Guide/joint-matrix/joint-matrix.cpp b/Publications/GPU-Opt-Guide/joint-matrix/joint-matrix.cpp
index 87d19b1f28..ee73a115af 100644
--- a/Publications/GPU-Opt-Guide/joint-matrix/joint-matrix.cpp
+++ b/Publications/GPU-Opt-Guide/joint-matrix/joint-matrix.cpp
@@ -7,18 +7,19 @@
 #include <iostream>
 #include <sycl/sycl.hpp>
 
-// using joint_matrix = sycl::ext::oneapi::experimental::matrix;
 using use = sycl::ext::oneapi::experimental::matrix::use;
 using layout = sycl::ext::oneapi::experimental::matrix::layout;
 using bfloat16 = sycl::ext::oneapi::bfloat16;
 
-#define SG_SZ 16
+constexpr size_t SG_SZ = 16;
 
-#define TM 8
-#define TN SG_SZ
-#define TK 16
+constexpr size_t TM = 8;
+constexpr size_t TN = SG_SZ;
+constexpr size_t TK = 16;
 
-#define BF16_EPSILON 0.00781250
+constexpr float ALPHA = 2.0;
+
+constexpr float BF16_EPSILON = 0.00781250;
 
 template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
 private:
@@ -42,10 +43,9 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
 
   sycl::queue q;
   q.submit([&](sycl::handler &cgh) {
-    sycl::accessor accC(bufC, cgh, sycl::read_write, sycl::no_init);
+    sycl::accessor accC(bufC, cgh, sycl::read_write);
     sycl::accessor accA(bufA, cgh, sycl::read_only);
     sycl::accessor accB(bufB, cgh, sycl::read_only);
-
     cgh.parallel_for(
         sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
         [=](sycl::nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
@@ -66,30 +66,32 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
           // For B, we assume B has been already VNNIed.
           sycl::ext::oneapi::experimental::matrix::joint_matrix<
              sycl::sub_group, bfloat16, use::b, TK, TN,
-              sycl::ext::intel::experimental::matrix::layout::packed>
+              layout::ext_intel_packed>
               sub_b;
           sycl::ext::oneapi::experimental::matrix::joint_matrix<
               sycl::sub_group, float, use::accumulator, TM, TN>
               sub_c;
-          joint_matrix_load(sg, sub_c,
-                            accC.get_pointer() + (sg_startx * TM) * N +
-                                sg_starty / SG_SZ * TN,
-                            N, layout::row_major);
-          for (int k = 0; k < K / TK; k += 1) { //
+          joint_matrix_fill(sg, sub_c, 1.0);
+          for (int k = 0; k < K / TK; k += 1) {
             joint_matrix_load(
-                sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
+                sg, sub_a,
+                accA.template get_multi_ptr<sycl::access::decorated::no>() +
+                    (sg_startx * TM) * K + k * TK,
                 K);
-            joint_matrix_load(sg, sub_b,
-                              accB.get_pointer() + (k * TK / 2) * (N * 2) +
-                                  sg_starty / SG_SZ * TN * 2,
-                              N * 2);
-            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+            joint_matrix_load(
+                sg, sub_b,
+                accB.template get_multi_ptr<sycl::access::decorated::no>() +
+                    (k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2,
+                N * 2);
+            joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
           }
-          joint_matrix_store(sg, sub_c,
-                             accC.get_pointer() + (sg_startx * TM) * N +
-                                 sg_starty / SG_SZ * TN,
-                             N, layout::row_major);
+          joint_matrix_apply(sg, sub_c, [=](float &x) { x *= ALPHA; });
+          joint_matrix_store(
+              sg, sub_c,
+              accC.template get_multi_ptr<sycl::access::decorated::no>() +
+                  (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
+              N, layout::row_major);
         }); // parallel for
   }).wait(); // kernel end
 }
 
@@ -100,53 +102,43 @@ static constexpr size_t MATRIX_N = TN * 2;
 static constexpr size_t MATRIX_K = TK * 2;
 bfloat16 A[MATRIX_M][MATRIX_K];
 bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
-unsigned short Aref[MATRIX_M][MATRIX_K];
-unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
 float C[MATRIX_M][MATRIX_N];
 float D[MATRIX_M][MATRIX_N];
 
-float make_fp32(short x) {
-  unsigned int y = x;
+float make_fp32(bfloat16 x) {
+  unsigned int y = *((int *)&x);
   y = y << 16;
   float *res = reinterpret_cast<float *>(&y);
   return *res;
 }
 
-unsigned short make_bf16(float x) {
-  int *res = reinterpret_cast<int *>(&x);
-  *res = *res >> 16;
-  return (unsigned short)*res;
-}
-
 void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
                          int K) {
   for (int m = 0; m < M; m++)
     for (int n = 0; n < N; n++) {
       for (int k = 0; k < K; k++) {
-        short *va = (short *)(A_mem + m * K + k);
-        short *vb = (short *)(B_mem + k * N + n);
+        // Because B was assumed VNNIed
+        bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
+        bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
         float acc = *((float *)(C_mem + m * N + n));
         for (int i = 0; i < 2; i++) {
           acc += (make_fp32(va[i]) * make_fp32(vb[i]));
         }
         *((float *)(C_mem + m * N + n)) = acc;
       }
+      *((float *)(C_mem + m * N + n)) *= ALPHA;
     }
 }
 
 int main() {
   for (int i = 0; i < MATRIX_M; i++) {
     for (int j = 0; j < MATRIX_K; j++) {
-      // bfloat16 is created using unsigned short since conversion from float to
-      // bfloat16 is not supported on the host side yet
       A[i][j] = bfloat16(1.0f * (i + j));
-      Aref[i][j] = make_bf16(1.0f * (i + j));
     }
   }
   for (int i = 0; i < MATRIX_K / 2; i++) {
     for (int j = 0; j < MATRIX_N * 2; j++) {
       B[i][j] = bfloat16(2.0f * i + 3.0f * j);
-      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
     }
   }
   for (int i = 0; i < MATRIX_M; i++) {
@@ -161,13 +153,13 @@ int main() {
   big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
   big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
   matrix_multiply(MC, MA, MB);
-  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
+  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                       MATRIX_N, MATRIX_K / 2);
 
   bool res = true;
   for (int i = 0; i < MATRIX_M; i++) {
     for (int j = 0; j < MATRIX_N; j++) {
-      if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON)
+      if ((fabs(C[i][j] - D[i][j])) > BF16_EPSILON)
         res = false;
     }
   }