diff --git a/ReadMe.md b/ReadMe.md index 510b37b..fe6ebd9 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -1,4 +1,8 @@ # Test Case |Engine| Type|Command line| |----|---|---| -|DML|MHA|`cross_runner.exe --iters 1 --type mha_dml mha_opts --data_type fp16 --layout nchw --mha_type qkv --shape_input 2,64,8,3,160`| \ No newline at end of file +|DML|MHA|`cross_runner.exe --iters 1 --type mha_dml mha_opts --data_type fp16 --layout nchw --mha_type qkv --shape_input 2,64,8,3,160`| +|CM|GEMM_QK_QKV|`cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type qk_qkv --data_type fp16 --layout nchw --shape_a 2,64,8,3,160 gemm_cm_opts --large_grf --tile_m 16 --tile_k 80 --tile_n 64 --lws_x 1 --lws_y 1 --lws_z 2 --slice_k 2 --dump_asm`| +|CM|GEMM_QK_QKV dpas|`cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type qk_qkv --use_dpas=1 --data_type fp16 --layout nchw --shape_a 2,64,8,3,160 gemm_cm_opts --large_grf --tile_m 16 --tile_k 80 --tile_n 64 --lws_x 8 --lws_y 2 --lws_z 1 --dump_asm`| +|CM|GEMM_AB| `cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type ab --data_type fp16 --layout nchw --shape_a 1,1,8192,320 --shape_b 1,1,320,320 --b_managed gemm_cm_opts --large_grf --tile_m 16 --tile_k 16 --tile_n 64 --dump_asm`| +|DML|GEMM_AB| `cross_runner.exe --iters 100 --type gemm_dml gemm_opts --data_type fp16 --layout nchw --gemm_type ab --shape_a 1,1,8192,320 --shape_b 1,1,320,320 --b_managed`| \ No newline at end of file diff --git a/tools/cross_runner/kernels/gemm_nchw_fp16.cpp b/tools/cross_runner/kernels/gemm_nchw_fp16.cpp index f38532b..148273d 100644 --- a/tools/cross_runner/kernels/gemm_nchw_fp16.cpp +++ b/tools/cross_runner/kernels/gemm_nchw_fp16.cpp @@ -69,7 +69,8 @@ extern "C" _GENX_MAIN_ void gemm_nchw_fp16( matrix accu_out = accu; // if DT_ACCU == DT then compiler removes this line - accu_out *= DT(SCALE); + accu_out *= DT(ALPHA); + accu_out += DT(BETA); #pragma unroll for(uint32_t i = 0; i < TILE_M; i++) diff --git a/tools/cross_runner/kernels/mha_qk_qkv_gemm_dpas.cpp b/tools/cross_runner/kernels/mha_qk_qkv_gemm_dpas.cpp new file mode 100644 index 0000000..f77e30d --- /dev/null +++ b/tools/cross_runner/kernels/mha_qk_qkv_gemm_dpas.cpp @@ -0,0 +1,170 @@ +#include +#include + +#if !defined(EMPTY) + +#define SIZE_OF_HF16_BYTE 2 + +#define WG_TILE_M 64 +#define WG_TILE_N 64 + +#define SG_TILE_M 8 +#define SG_TILE_N 32 + +#define SG_TILE_NUM_ROWS 8 +#define HALF half +#define FLOAT float +#define DIM_X 0 +#define DIM_Y 1 +#define DIM_Z 2 +#define BASE_OUTPUT_OFFSET 0 + +_GENX_ inline void myDPAS8(matrix_ref matA, + matrix_ref matB, + matrix_ref result) +{ + result = cm_dpas(result.format(), matB.format(), matA.format()); +} + +#endif + +extern "C" _GENX_MAIN_ void +mha_qk_qkv_gemm_dpas(SurfaceIndex INMTXa[[type("buffer_t half")]], // 0 input qkv surface + SurfaceIndex OUTMTX[[type("buffer_t half")]] // 1 output qxk surface +) { +#if !defined(EMPTY) + + //A matrix format: [K/16][M][16k] + //A tile: 32Mx16K + vector readA1;//M=0..7,K=0..15 + matrix_ref readA1_m = readA1.format(); + + //B matrix format: [K/16][N/8][8K][8N][2K] + //B tile: 32Nx16K + matrix readB1;//N=0..7,K=0..15 + matrix readB2;//N=8..15,K=0..15 + matrix readB3;//N=16..23,K=0..15 + matrix readB4;//N=24..32,K=0..15 + + matrix_ref readB1_m = readB1.format(); + matrix_ref readB2_m = readB2.format(); + matrix_ref readB3_m = readB3.format(); + matrix_ref readB4_m = readB4.format(); + + matrix result11; + matrix result12; + matrix result13; + matrix result14; + + matrix_ref result11ref = result11; + matrix_ref result12ref = result12; + matrix_ref result13ref = result13; + matrix_ref result14ref = result14; + + uint gidY = cm_group_id(DIM_Y); + uint gidX = cm_group_id(DIM_X); + uint gidZ = cm_group_id(DIM_Z); + uint tidX = cm_local_id(DIM_X); + uint tidY = cm_local_id(DIM_Y); + uint tidZ = cm_local_id(DIM_Z); + + //input surface control variables + const unsigned input_head_count_stride_qkv = SIZE_HEAD_SIZE * 3; + const unsigned input_sequence_stride_qkv = SIZE_NUM_HEADS * input_head_count_stride_qkv; + const unsigned input_batch_stride_qkv = SIZE_SEQ_LEN * input_sequence_stride_qkv; + const unsigned input_surface_tile_base_offset_q = tidX * SG_TILE_M + gidX * WG_TILE_M; + const unsigned input_surface_tile_base_offset_k = tidY * SG_TILE_N + gidY * WG_TILE_N; + const unsigned input_cacheline_stride = input_sequence_stride_qkv * SIZE_OF_HF16_BYTE; + const unsigned input_cacheline_stride_rows = SG_TILE_NUM_ROWS * input_cacheline_stride; + + //output surface control variables + const unsigned output_head_count_stride = SIZE_SEQ_LEN * SIZE_SEQ_LEN; + const unsigned output_batch_stride = SIZE_NUM_HEADS * output_head_count_stride; + const unsigned output_tile_base_offset = ((tidX * SG_TILE_M + gidX * WG_TILE_M) * SIZE_SEQ_LEN) + (tidY * SG_TILE_N + gidY * WG_TILE_N); + //const int sg_tile_base_offset = tidX * SG_TILE_M * SIZE_SEQ_LEN + tidY * SG_TILE_N; + + const unsigned batch_count = gidZ/SIZE_NUM_HEADS; + const unsigned head_count = gidZ % SIZE_NUM_HEADS; + const unsigned input_surface_batch_start_offset = input_batch_stride_qkv * batch_count; + const unsigned output_surface_batch_start_offset = output_batch_stride * batch_count; + + const unsigned head_start_offset = output_head_count_stride * head_count; + const unsigned output_surface_tile_start_offset = BASE_OUTPUT_OFFSET + output_surface_batch_start_offset + head_start_offset + output_tile_base_offset; // + sg_tile_base_offset; + const unsigned input_surface_tile_start_offset_q = head_count * input_head_count_stride_qkv + batch_count * input_surface_batch_start_offset; + const unsigned input_surface_tile_start_offset_k = head_count * input_head_count_stride_qkv + batch_count * input_surface_batch_start_offset + SIZE_HEAD_SIZE; + + //init the accumulators + result11.select_all() = 0; + result12.select_all() = 0; + result13.select_all() = 0; + result14.select_all() = 0; + + for (int head_size_step = 0; head_size_step < SIZE_HEAD_SIZE; head_size_step += 16) //iterates to process the entire K for A and B. + { + const bool is_head_size_step_multiple_16 = (head_size_step + 16 <= SIZE_HEAD_SIZE) ? true : false; + const unsigned input_surface_tile_start_q = (input_surface_tile_start_offset_q + head_size_step + input_surface_tile_base_offset_q * input_sequence_stride_qkv) * SIZE_OF_HF16_BYTE;; + const unsigned input_surface_tile_start_k = (input_surface_tile_start_offset_k + head_size_step + input_surface_tile_base_offset_k * input_sequence_stride_qkv) * SIZE_OF_HF16_BYTE;; + readA1.select_all() = 0.0; + readB1.select_all() = 0.0; + readB2.select_all() = 0.0; + readB3.select_all() = 0.0; + readB4.select_all() = 0.0; + + #pragma unroll + for (int row = 0; row < SG_TILE_NUM_ROWS; row++) + { + const unsigned row_offset_in_bytes = row * SG_TILE_NUM_ROWS * SIZE_OF_HF16_BYTE; + const unsigned rowX2 = row * 2; + const unsigned input_surface_strride = row * input_cacheline_stride; + const unsigned input_surface_q_CL1 = input_surface_tile_start_q + input_surface_strride; + const unsigned input_surface_k_CL1 = input_surface_tile_start_k + input_surface_strride; + const unsigned input_surface_k_CL2 = input_surface_k_CL1 + input_cacheline_stride_rows; + const unsigned input_surface_k_CL3 = input_surface_k_CL2 + input_cacheline_stride_rows; + const unsigned input_surface_k_CL4 = input_surface_k_CL3 + input_cacheline_stride_rows; + + // 8M x 16K (one complete row per iteration) + if(is_head_size_step_multiple_16) + { + readA1.select<16,1>(row_offset_in_bytes).format() = cm_load(INMTXa, input_surface_q_CL1); + readB1.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL1); + readB2.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL2); + readB3.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL3); + readB4.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL4); + } + else + { + readA1.select<8,1>(row_offset_in_bytes).format() = cm_load(INMTXa, input_surface_q_CL1); + readB1.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL1); + readB2.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL2); + readB3.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL3); + readB4.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL4); + } + } + + myDPAS8(readA1_m, readB1_m, result11ref); + myDPAS8(readA1_m, readB2_m, result12ref); + myDPAS8(readA1_m, readB3_m, result13ref); + myDPAS8(readA1_m, readB4_m, result14ref); + } + + vector result_hf16_CL1 = 0.0; + result11 *= HALF(ALPHA); + result12 *= HALF(ALPHA); + result13 *= HALF(ALPHA); + result14 *= HALF(ALPHA); + + #pragma unroll + for(int j = 0; j < SG_TILE_NUM_ROWS; j++) + { + const unsigned write_index_base = j * SIZE_SEQ_LEN; + const unsigned write_index_0 = (output_surface_tile_start_offset + write_index_base) * SIZE_OF_HF16_BYTE; + + result_hf16_CL1.select<8, 1>(0) = result11ref.select<1, 1, 8, 1>(j, 0); + result_hf16_CL1.select<8, 1>(8) = result12ref.select<1, 1, 8, 1>(j, 0); + result_hf16_CL1.select<8, 1>(16) = result13ref.select<1, 1, 8, 1>(j, 0); + result_hf16_CL1.select<8, 1>(24) = result14ref.select<1, 1, 8, 1>(j, 0); + + cm_store(OUTMTX, write_index_0, result_hf16_CL1.format()); + } +#endif // !defined(EMPTY) +} diff --git a/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp b/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp index 90551c5..c27dd8b 100644 --- a/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp +++ b/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp @@ -7,6 +7,7 @@ #else #define DT_ACCU DT #endif +#define BASE_OUTPUT_OFFSET 0 #define LOAD_SIMD_SIZE 16 #define K_PER_LOAD 8 @@ -56,9 +57,9 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm( ) { - const uint32_t thread_id_0 = cm_group_id(0) * cm_local_size(0) + cm_local_id(0); - const uint32_t thread_id_1 = cm_group_id(1) * cm_local_size(1) + cm_local_id(1); - const uint32_t thread_id_2 = cm_group_id(2) * cm_local_size(2) + cm_local_id(2); + const uint32_t thread_id_0 = cm_group_id(0) * LWS_SIZE_X + cm_local_id(0); + const uint32_t thread_id_1 = cm_group_id(1) * LWS_SIZE_Y + cm_local_id(1); + const uint32_t thread_id_2 = cm_group_id(2) * LWS_SIZE_Z + cm_local_id(2); const uint32_t k_slice_thread_offset = cm_local_id(2); const uint32_t batch_thread_offset = cm_group_id(2) / SIZE_NUM_HEADS; @@ -82,6 +83,7 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm( #if SLM_KN_SHARING cm_slm_init(TILE_K * TILE_N * sizeof(DT)); + uint slm_buffer = cm_slm_alloc(TILE_K * TILE_N * sizeof(DT)); const uint32_t th_local_id = cm_local_id(1); #endif @@ -154,7 +156,13 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm( #pragma unroll for(uint32_t j = 0; j < TILE_M; j++) { +#if ACCU_IS_FP32 + vector input_b_fp32 = vector(input_b); + vector input_a_fp32 = vector(input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate()); + accu.select<1, 1, TILE_N, 1>(j, 0) += input_b_fp32 * input_a_fp32; +#else accu.select<1, 1, TILE_N, 1>(j, 0) += input_b * input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate(); +#endif } #endif } @@ -175,7 +183,16 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm( #pragma unroll for(uint32_t m = 0; m < TILE_M; m++) { + +#if 0 + vector input_b_fp32 = vector(input_b); + vector input_a_fp32 = vector(input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate()); + accu.select<1, 1, TILE_N, 1>(j, 0) += input_b_fp32 * input_a_fp32; +#else accu.select<1, 1, TILE_N, 1>(m, 0) += input_b * input.select<1, 1, 1, 1>(m, k).replicate(); +#endif + + } } #endif @@ -183,7 +200,8 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm( #if SLICE_K > 1 const uint32_t TILE_N_PACKED = TILE_N / (sizeof(uint32_t)/sizeof(DT_ACCU)); cm_slm_init(TILE_M * TILE_N * sizeof(DT_ACCU) * (LWS_SIZE_Z - 1)); - + uint slm_buffer = cm_slm_alloc(TILE_M * TILE_N * sizeof(DT_ACCU) * (LWS_SIZE_Z - 1)); + if(cm_local_id(2) > 0) { #pragma unroll @@ -220,13 +238,15 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm( const uint32_t output_store_size = (TILE_N * sizeof(DT)) / sizeof(uint32_t); uint32_t output_offset = - (batch_thread_offset * SIZE_NUM_HEADS + head_thread_offset) * SIZE_M * SIZE_N * sizeof(DT) - + thread_id_1 * TILE_M * SIZE_N * sizeof(DT) - + (thread_id_0 * TILE_N * sizeof(DT)); + (batch_thread_offset * SIZE_NUM_HEADS * SIZE_M * SIZE_N + + head_thread_offset * SIZE_M * SIZE_N + + thread_id_1 * TILE_M * SIZE_N + + thread_id_0 * TILE_N + + BASE_OUTPUT_OFFSET) * sizeof(DT); matrix accu_out = accu; // if DT_ACCU == DT then compiler removes this line - accu_out *= DT(SCALE); + accu_out *= DT(ALPHA); for(uint32_t i = 0; i < TILE_M; i++) { vector_ref accu_0_packed = accu_out.select<1, 1, TILE_N, 1>(i, 0).format(); diff --git a/tools/cross_runner/src/gemm.h b/tools/cross_runner/src/gemm.h index 345e333..3f0cd3b 100644 --- a/tools/cross_runner/src/gemm.h +++ b/tools/cross_runner/src/gemm.h @@ -21,6 +21,7 @@ enum class GemmType GemmType_SV_S_KV, }; + namespace { // a bit of hack :)> @@ -378,16 +379,16 @@ class Gemm : public DirectMlBaseNode std::vector input_binds{}; input_binds.push_back({ nullptr, 0, 0 }); // tensor a - + //tensor b - if (input_1_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML) + if (input_1_ && input_1_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML) { - input_binds.push_back({ resource_b, 0, resource_b->GetDesc().Width }); + input_binds.push_back({ resource_b, 0, resource_b->GetDesc().Width }); } - else - { + /* else + { input_binds.push_back({ nullptr, 0, 0 }); - } + }*/ // tensor c if (input_2_ && input_2_->GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML) @@ -395,10 +396,10 @@ class Gemm : public DirectMlBaseNode assert(resource_c != nullptr); input_binds.push_back({ resource_c, 0, resource_c->GetDesc().Width }); } - else + /* else { input_binds.push_back({ nullptr, 0, 0 }); - } + }*/ DML_BUFFER_ARRAY_BINDING input_bind{}; input_bind.BindingCount = static_cast(input_binds.size()); @@ -450,6 +451,8 @@ class GemmBaseDispatcher : public NodeDispatcher bool b_managed = false; bool b_transposed = false; + bool use_dpas = false; + inline static void add_cli_options(CLI::App* opts, create_params_t& params) { add_data_type_cli_option(opts, "--data_type", params.dt)->required(); @@ -461,7 +464,7 @@ class GemmBaseDispatcher : public NodeDispatcher opts->add_option("--shape_c", params.shape_c); opts->add_flag("--b_transposed", params.b_transposed)->default_val(false); - opts->add_flag("--b_managed", params.b_managed)->default_val(false);; + opts->add_flag("--b_managed", params.b_managed)->default_val(false); opts->add_option("--alpha", params.alpha); opts->add_option("--beta", params.beta); @@ -469,7 +472,7 @@ class GemmBaseDispatcher : public NodeDispatcher opts->add_flag("--fuse_softmax", params.fuse_softmax)->default_val(false); opts->add_option("--gemm_type", params.type, "Name of the type of GEMM to run.") - ->check(CLI::IsMember({ GemmType::GemmType_AB, GemmType::GemmType_QK_QKV, GemmType::GemmType_SV_S_QKV, GemmType::GemmType_QK_Q_KV, GemmType::GemmType_SV_S_KV }))-> + ->check(CLI::IsMember({ GemmType::GemmType_AB, GemmType::GemmType_QK_QKV, GemmType::GemmType_SV_S_QKV, GemmType::GemmType_QK_Q_KV, GemmType::GemmType_SV_S_KV}))-> transform(CLI::Transformer(std::map{ { "ab", GemmType::GemmType_AB }, { "qk_qkv", GemmType::GemmType_QK_QKV }, @@ -478,6 +481,8 @@ class GemmBaseDispatcher : public NodeDispatcher { "sv_s_kv", GemmType::GemmType_SV_S_KV }, }, CLI::ignore_case))->required(); + opts-> add_flag("--use_dpas", params.use_dpas)->default_val(false); + } }; public: @@ -516,7 +521,7 @@ class GemmBaseDispatcher : public NodeDispatcher assert(!input_data_a_.empty()); assert(!input_data_b_.empty()); } - else if (params_.type == GemmType::GemmType_QK_Q_KV) + else if (params_.type == GemmType::GemmType_QK_Q_KV ) { assert(params_.shape_a.get_dims_count() == 3); // q input assert(params_.shape_b.get_dims_count() == 5); // q_kv input @@ -756,7 +761,7 @@ class GemmBaseDispatcher : public NodeDispatcher { return params_.shape_a.h; } - else if (params_.type == GemmType::GemmType_QK_QKV || params_.type == GemmType::GemmType_QK_Q_KV) + else if (params_.type == GemmType::GemmType_QK_QKV || params_.type == GemmType::GemmType_QK_Q_KV) { return params_.shape_a.c; } @@ -788,7 +793,7 @@ class GemmBaseDispatcher : public NodeDispatcher { return params_.b_transposed ? params_.shape_b.h : params_.shape_b.w; } - else if (params_.type == GemmType::GemmType_QK_QKV) + else if (params_.type == GemmType::GemmType_QK_QKV ) { return params_.shape_a.c; } @@ -886,6 +891,7 @@ class GemmCmDispatcher : public GemmBaseDispatcher opts->add_option("--lws_z", params.lws[2]); opts->add_option("--slice_k", params.slice_k); + } }; @@ -942,7 +948,7 @@ class GemmCmDispatcher : public GemmBaseDispatcher } #endif //cm_params_.lws[0] = 32; - cm_params_.lws[1] = 16; + //cm_params_.lws[1] = 16; } else if(params_.type == GemmType::GemmType_QK_Q_KV) { @@ -1037,7 +1043,8 @@ class GemmCmDispatcher : public GemmBaseDispatcher add_define("SIZE_HEAD_SIZE", params_.shape_b.w); } - add_define("SCALE", params_.alpha); + add_define("ALPHA", params_.alpha); + add_define("BETA", params_.beta); add_define("DT", "half"); @@ -1063,13 +1070,22 @@ class GemmCmDispatcher : public GemmBaseDispatcher std::cout << build_options_final << std::endl; } - auto kernel_source_content = [](GemmType type) + auto kernel_source_content = [](GemmType type, bool dpas_flag) { std::string path = ""; switch (type) { case GemmType::GemmType_AB: path = "gemm_nchw_fp16.cpp"; break; - case GemmType::GemmType_QK_QKV: path = "mha_qk_qkv_gemm_fp16.cpp"; break; + case GemmType::GemmType_QK_QKV: + { + if(dpas_flag == true) + { + path = "mha_qk_qkv_gemm_dpas.cpp"; + }else{ + path = "mha_qk_qkv_gemm_fp16.cpp"; + } + break; + } case GemmType::GemmType_SV_S_QKV: path = "mha_sv_s_qkv_gemm_fp16.cpp"; break; case GemmType::GemmType_SV_S_KV: path = "mha_sv_s_kv_gemm_fp16.cpp"; break; case GemmType::GemmType_QK_Q_KV: path = "mha_qk_q_kv_gemm_fp16.cpp"; break; @@ -1084,7 +1100,7 @@ class GemmCmDispatcher : public GemmBaseDispatcher throw std::runtime_error(msg); } return std::string((std::istreambuf_iterator(file)), (std::istreambuf_iterator())); - }(params_.type); + }(params_.type,params_.use_dpas); CD3DX12_SHADER_BYTECODE byte_code; byte_code.pShaderBytecode = kernel_source_content.data(); @@ -1161,9 +1177,19 @@ class GemmCmDispatcher : public GemmBaseDispatcher } else if (params_.type == GemmType::GemmType_QK_QKV) { - gws_x = get_N() / cm_params_.tile_n; // n first - gws_y = get_M() / cm_params_.tile_m; // m second - gws_z = get_batch() * get_channels() * cm_params_.slice_k; + if(params_.use_dpas) + { + gws_x = 8; + gws_y = 2; + gws_z = 16; + + } + else + { + gws_x = get_N() / cm_params_.tile_n; // n first + gws_y = get_M() / cm_params_.tile_m; // m second + gws_z = get_batch() * get_channels() * cm_params_.slice_k; + } } else {