diff --git a/ReadMe.md b/ReadMe.md
index 510b37b..fe6ebd9 100644
--- a/ReadMe.md
+++ b/ReadMe.md
@@ -1,4 +1,8 @@
# Test Case
|Engine| Type|Command line|
|----|---|---|
-|DML|MHA|`cross_runner.exe --iters 1 --type mha_dml mha_opts --data_type fp16 --layout nchw --mha_type qkv --shape_input 2,64,8,3,160`|
\ No newline at end of file
+|DML|MHA|`cross_runner.exe --iters 1 --type mha_dml mha_opts --data_type fp16 --layout nchw --mha_type qkv --shape_input 2,64,8,3,160`|
+|CM|GEMM_QK_QKV|`cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type qk_qkv --data_type fp16 --layout nchw --shape_a 2,64,8,3,160 gemm_cm_opts --large_grf --tile_m 16 --tile_k 80 --tile_n 64 --lws_x 1 --lws_y 1 --lws_z 2 --slice_k 2 --dump_asm`|
+|CM|GEMM_QK_QKV dpas|`cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type qk_qkv --use_dpas=1 --data_type fp16 --layout nchw --shape_a 2,64,8,3,160 gemm_cm_opts --large_grf --tile_m 16 --tile_k 80 --tile_n 64 --lws_x 8 --lws_y 2 --lws_z 1 --dump_asm`|
+|CM|GEMM_AB| `cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type ab --data_type fp16 --layout nchw --shape_a 1,1,8192,320 --shape_b 1,1,320,320 --b_managed gemm_cm_opts --large_grf --tile_m 16 --tile_k 16 --tile_n 64 --dump_asm`|
+|DML|GEMM_AB| `cross_runner.exe --iters 100 --type gemm_dml gemm_opts --data_type fp16 --layout nchw --gemm_type ab --shape_a 1,1,8192,320 --shape_b 1,1,320,320 --b_managed`|
\ No newline at end of file
diff --git a/tools/cross_runner/kernels/gemm_nchw_fp16.cpp b/tools/cross_runner/kernels/gemm_nchw_fp16.cpp
index f38532b..148273d 100644
--- a/tools/cross_runner/kernels/gemm_nchw_fp16.cpp
+++ b/tools/cross_runner/kernels/gemm_nchw_fp16.cpp
@@ -69,7 +69,8 @@ extern "C" _GENX_MAIN_ void gemm_nchw_fp16(
matrix
accu_out = accu; // if DT_ACCU == DT then compiler removes this line
- accu_out *= DT(SCALE);
+ accu_out *= DT(ALPHA);
+ accu_out += DT(BETA);
#pragma unroll
for(uint32_t i = 0; i < TILE_M; i++)
diff --git a/tools/cross_runner/kernels/mha_qk_qkv_gemm_dpas.cpp b/tools/cross_runner/kernels/mha_qk_qkv_gemm_dpas.cpp
new file mode 100644
index 0000000..f77e30d
--- /dev/null
+++ b/tools/cross_runner/kernels/mha_qk_qkv_gemm_dpas.cpp
@@ -0,0 +1,170 @@
+#include
+#include
+
+#if !defined(EMPTY)
+
+#define SIZE_OF_HF16_BYTE 2 // bytes per fp16 element; used to turn element offsets into byte offsets
+// Work-group tile extents: multiplied by the group id to locate a group's rows/columns.
+#define WG_TILE_M 64 // scaled by group id X (rows of Q / output rows)
+#define WG_TILE_N 64 // scaled by group id Y (rows of K / output columns)
+// Per-thread tile extents: multiplied by the local id inside a work-group.
+#define SG_TILE_M 8 // scaled by local id X
+#define SG_TILE_N 32 // scaled by local id Y
+// Trip count of the per-row load loop and of the per-row store loop.
+#define SG_TILE_NUM_ROWS 8
+#define HALF half // element type alias used for the ALPHA scale cast
+#define FLOAT float
+#define DIM_X 0 // dimension selectors passed to cm_group_id / cm_local_id
+#define DIM_Y 1
+#define DIM_Z 2
+#define BASE_OUTPUT_OFFSET 0 // constant element offset added to every output tile start
+// Helper around cm_dpas: `result` is passed as the first operand and is also overwritten with the return value.
+_GENX_ inline void myDPAS8(matrix_ref matA,
+ matrix_ref matB,
+ matrix_ref result)
+{
+ result = cm_dpas(result.format(), matB.format(), matA.format()); // note the operand order: matB is passed before matA
+}
+
+#endif
+// Computes ALPHA * (Q x K^T) for one (batch, head) slice of a packed QKV buffer laid out as [batch][seq][head][q|k|v][head_size].
+extern "C" _GENX_MAIN_ void
+mha_qk_qkv_gemm_dpas(SurfaceIndex INMTXa[[type("buffer_t half")]], // 0 input qkv surface
+    SurfaceIndex OUTMTX[[type("buffer_t half")]] // 1 output qxk surface
+) {
+#if !defined(EMPTY)
+    // Each thread owns an SG_TILE_M x SG_TILE_N output tile, accumulated in four 8-column pieces (result11..result14).
+    //A matrix format: [K/16][M][16k]
+    //A tile: 8Mx16K
+    vector readA1;//M=0..7,K=0..15
+    matrix_ref readA1_m = readA1.format();
+
+    //B matrix format: [K/16][N/8][8K][8N][2K]
+    //B tile: 32Nx16K
+    matrix readB1;//N=0..7,K=0..15
+    matrix readB2;//N=8..15,K=0..15
+    matrix readB3;//N=16..23,K=0..15
+    matrix readB4;//N=24..31,K=0..15
+
+    matrix_ref readB1_m = readB1.format();
+    matrix_ref readB2_m = readB2.format();
+    matrix_ref readB3_m = readB3.format();
+    matrix_ref readB4_m = readB4.format();
+
+    matrix result11;
+    matrix result12;
+    matrix result13;
+    matrix result14;
+
+    matrix_ref result11ref = result11;
+    matrix_ref result12ref = result12;
+    matrix_ref result13ref = result13;
+    matrix_ref result14ref = result14;
+
+    uint gidY = cm_group_id(DIM_Y);
+    uint gidX = cm_group_id(DIM_X);
+    uint gidZ = cm_group_id(DIM_Z);
+    uint tidX = cm_local_id(DIM_X);
+    uint tidY = cm_local_id(DIM_Y);
+    // (the Z local id is not needed: Z only selects batch/head through gidZ)
+
+    //input surface control variables (all strides are in fp16 elements unless noted)
+    const unsigned input_head_count_stride_qkv = SIZE_HEAD_SIZE * 3;
+    const unsigned input_sequence_stride_qkv = SIZE_NUM_HEADS * input_head_count_stride_qkv;
+    const unsigned input_batch_stride_qkv = SIZE_SEQ_LEN * input_sequence_stride_qkv;
+    const unsigned input_surface_tile_base_offset_q = tidX * SG_TILE_M + gidX * WG_TILE_M;
+    const unsigned input_surface_tile_base_offset_k = tidY * SG_TILE_N + gidY * WG_TILE_N;
+    const unsigned input_cacheline_stride = input_sequence_stride_qkv * SIZE_OF_HF16_BYTE; // bytes between consecutive sequence rows
+    const unsigned input_cacheline_stride_rows = SG_TILE_NUM_ROWS * input_cacheline_stride;
+
+    //output surface control variables
+    const unsigned output_head_count_stride = SIZE_SEQ_LEN * SIZE_SEQ_LEN;
+    const unsigned output_batch_stride = SIZE_NUM_HEADS * output_head_count_stride;
+    const unsigned output_tile_base_offset = ((tidX * SG_TILE_M + gidX * WG_TILE_M) * SIZE_SEQ_LEN) + (tidY * SG_TILE_N + gidY * WG_TILE_N);
+    // (the per-thread tile offset is already folded into output_tile_base_offset)
+
+    const unsigned batch_count = gidZ / SIZE_NUM_HEADS;
+    const unsigned head_count = gidZ % SIZE_NUM_HEADS;
+    const unsigned input_surface_batch_start_offset = input_batch_stride_qkv * batch_count;
+    const unsigned output_surface_batch_start_offset = output_batch_stride * batch_count;
+    // NOTE: both *_batch_start_offset values already include batch_count; scaling them again would skip batches.
+    const unsigned head_start_offset = output_head_count_stride * head_count;
+    const unsigned output_surface_tile_start_offset = BASE_OUTPUT_OFFSET + output_surface_batch_start_offset + head_start_offset + output_tile_base_offset;
+    const unsigned input_surface_tile_start_offset_q = head_count * input_head_count_stride_qkv + input_surface_batch_start_offset;
+    const unsigned input_surface_tile_start_offset_k = head_count * input_head_count_stride_qkv + input_surface_batch_start_offset + SIZE_HEAD_SIZE; // K follows Q within a head
+
+    //init the accumulators
+    result11.select_all() = 0;
+    result12.select_all() = 0;
+    result13.select_all() = 0;
+    result14.select_all() = 0;
+
+    for (int head_size_step = 0; head_size_step < SIZE_HEAD_SIZE; head_size_step += 16) //iterates to process the entire K for A and B.
+    {
+        const bool is_head_size_step_multiple_16 = (head_size_step + 16 <= SIZE_HEAD_SIZE); // false only for a partial last chunk
+        const unsigned input_surface_tile_start_q = (input_surface_tile_start_offset_q + head_size_step + input_surface_tile_base_offset_q * input_sequence_stride_qkv) * SIZE_OF_HF16_BYTE;
+        const unsigned input_surface_tile_start_k = (input_surface_tile_start_offset_k + head_size_step + input_surface_tile_base_offset_k * input_sequence_stride_qkv) * SIZE_OF_HF16_BYTE;
+        readA1.select_all() = 0.0;
+        readB1.select_all() = 0.0;
+        readB2.select_all() = 0.0;
+        readB3.select_all() = 0.0;
+        readB4.select_all() = 0.0;
+
+        #pragma unroll
+        for (int row = 0; row < SG_TILE_NUM_ROWS; row++)
+        {
+            const unsigned row_offset_in_elems = row * 16; // element (not byte) offset into readA1: 16 K-values per row
+            const unsigned rowX2 = row * 2;
+            const unsigned input_surface_stride = row * input_cacheline_stride;
+            const unsigned input_surface_q_CL1 = input_surface_tile_start_q + input_surface_stride;
+            const unsigned input_surface_k_CL1 = input_surface_tile_start_k + input_surface_stride;
+            const unsigned input_surface_k_CL2 = input_surface_k_CL1 + input_cacheline_stride_rows;
+            const unsigned input_surface_k_CL3 = input_surface_k_CL2 + input_cacheline_stride_rows;
+            const unsigned input_surface_k_CL4 = input_surface_k_CL3 + input_cacheline_stride_rows;
+
+            // 8M x 16K (one complete row per iteration)
+            if(is_head_size_step_multiple_16)
+            {
+                readA1.select<16,1>(row_offset_in_elems).format() = cm_load(INMTXa, input_surface_q_CL1);
+                readB1.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL1);
+                readB2.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL2);
+                readB3.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL3);
+                readB4.select<8,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL4);
+            }
+            else // partial last chunk: only 8 K-values are loaded, the rest keep the zero fill from above
+            {
+                readA1.select<8,1>(row_offset_in_elems).format() = cm_load(INMTXa, input_surface_q_CL1);
+                readB1.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL1);
+                readB2.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL2);
+                readB3.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL3);
+                readB4.select<4,1,2,1>(0, rowX2).format() = cm_load(INMTXa, input_surface_k_CL4);
+            }
+        }
+
+        myDPAS8(readA1_m, readB1_m, result11ref);
+        myDPAS8(readA1_m, readB2_m, result12ref);
+        myDPAS8(readA1_m, readB3_m, result13ref);
+        myDPAS8(readA1_m, readB4_m, result14ref);
+    }
+
+    vector result_hf16_CL1 = 0.0;
+    result11 *= HALF(ALPHA);
+    result12 *= HALF(ALPHA);
+    result13 *= HALF(ALPHA);
+    result14 *= HALF(ALPHA);
+
+    #pragma unroll
+    for(int j = 0; j < SG_TILE_NUM_ROWS; j++)
+    {
+        const unsigned write_index_base = j * SIZE_SEQ_LEN;
+        const unsigned write_index_0 = (output_surface_tile_start_offset + write_index_base) * SIZE_OF_HF16_BYTE;
+
+        result_hf16_CL1.select<8, 1>(0) = result11ref.select<1, 1, 8, 1>(j, 0);
+        result_hf16_CL1.select<8, 1>(8) = result12ref.select<1, 1, 8, 1>(j, 0);
+        result_hf16_CL1.select<8, 1>(16) = result13ref.select<1, 1, 8, 1>(j, 0);
+        result_hf16_CL1.select<8, 1>(24) = result14ref.select<1, 1, 8, 1>(j, 0);
+
+        cm_store(OUTMTX, write_index_0, result_hf16_CL1.format());
+    }
+#endif // !defined(EMPTY)
+}
diff --git a/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp b/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp
index 90551c5..c27dd8b 100644
--- a/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp
+++ b/tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp
@@ -7,6 +7,7 @@
#else
#define DT_ACCU DT
#endif
+#define BASE_OUTPUT_OFFSET 0
#define LOAD_SIMD_SIZE 16
#define K_PER_LOAD 8
@@ -56,9 +57,9 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
)
{
- const uint32_t thread_id_0 = cm_group_id(0) * cm_local_size(0) + cm_local_id(0);
- const uint32_t thread_id_1 = cm_group_id(1) * cm_local_size(1) + cm_local_id(1);
- const uint32_t thread_id_2 = cm_group_id(2) * cm_local_size(2) + cm_local_id(2);
+ const uint32_t thread_id_0 = cm_group_id(0) * LWS_SIZE_X + cm_local_id(0);
+ const uint32_t thread_id_1 = cm_group_id(1) * LWS_SIZE_Y + cm_local_id(1);
+ const uint32_t thread_id_2 = cm_group_id(2) * LWS_SIZE_Z + cm_local_id(2);
const uint32_t k_slice_thread_offset = cm_local_id(2);
const uint32_t batch_thread_offset = cm_group_id(2) / SIZE_NUM_HEADS;
@@ -82,6 +83,7 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
#if SLM_KN_SHARING
cm_slm_init(TILE_K * TILE_N * sizeof(DT));
+ uint slm_buffer = cm_slm_alloc(TILE_K * TILE_N * sizeof(DT));
const uint32_t th_local_id = cm_local_id(1);
#endif
@@ -154,7 +156,13 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
#pragma unroll
for(uint32_t j = 0; j < TILE_M; j++)
{
+#if ACCU_IS_FP32
+ vector input_b_fp32 = vector(input_b);
+ vector input_a_fp32 = vector(input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate());
+ accu.select<1, 1, TILE_N, 1>(j, 0) += input_b_fp32 * input_a_fp32;
+#else
accu.select<1, 1, TILE_N, 1>(j, 0) += input_b * input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate();
+#endif
}
#endif
}
@@ -175,7 +183,16 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
#pragma unroll
for(uint32_t m = 0; m < TILE_M; m++)
{
+
+#if 0
+ vector input_b_fp32 = vector(input_b);
+ vector input_a_fp32 = vector(input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate());
+ accu.select<1, 1, TILE_N, 1>(j, 0) += input_b_fp32 * input_a_fp32;
+#else
accu.select<1, 1, TILE_N, 1>(m, 0) += input_b * input.select<1, 1, 1, 1>(m, k).replicate();
+#endif
+
+
}
}
#endif
@@ -183,7 +200,8 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
#if SLICE_K > 1
const uint32_t TILE_N_PACKED = TILE_N / (sizeof(uint32_t)/sizeof(DT_ACCU));
cm_slm_init(TILE_M * TILE_N * sizeof(DT_ACCU) * (LWS_SIZE_Z - 1));
-
+ uint slm_buffer = cm_slm_alloc(TILE_M * TILE_N * sizeof(DT_ACCU) * (LWS_SIZE_Z - 1));
+
if(cm_local_id(2) > 0)
{
#pragma unroll
@@ -220,13 +238,15 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
const uint32_t output_store_size = (TILE_N * sizeof(DT)) / sizeof(uint32_t);
uint32_t output_offset =
- (batch_thread_offset * SIZE_NUM_HEADS + head_thread_offset) * SIZE_M * SIZE_N * sizeof(DT)
- + thread_id_1 * TILE_M * SIZE_N * sizeof(DT)
- + (thread_id_0 * TILE_N * sizeof(DT));
+ (batch_thread_offset * SIZE_NUM_HEADS * SIZE_M * SIZE_N
+ + head_thread_offset * SIZE_M * SIZE_N
+ + thread_id_1 * TILE_M * SIZE_N
+ + thread_id_0 * TILE_N
+ + BASE_OUTPUT_OFFSET) * sizeof(DT);
matrix accu_out = accu; // if DT_ACCU == DT then compiler removes this line
- accu_out *= DT(SCALE);
+ accu_out *= DT(ALPHA);
for(uint32_t i = 0; i < TILE_M; i++)
{
vector_ref accu_0_packed = accu_out.select<1, 1, TILE_N, 1>(i, 0).format();
diff --git a/tools/cross_runner/src/gemm.h b/tools/cross_runner/src/gemm.h
index 345e333..3f0cd3b 100644
--- a/tools/cross_runner/src/gemm.h
+++ b/tools/cross_runner/src/gemm.h
@@ -21,6 +21,7 @@ enum class GemmType
GemmType_SV_S_KV,
};
+
namespace
{
// a bit of hack :)>
@@ -378,16 +379,16 @@ class Gemm : public DirectMlBaseNode
std::vector input_binds{};
input_binds.push_back({ nullptr, 0, 0 }); // tensor a
-
+
//tensor b
- if (input_1_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML)
+ if (input_1_ && input_1_.GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML)
{
- input_binds.push_back({ resource_b, 0, resource_b->GetDesc().Width });
+ input_binds.push_back({ resource_b, 0, resource_b->GetDesc().Width });
}
- else
- {
+ /* else
+ {
input_binds.push_back({ nullptr, 0, 0 });
- }
+ }*/
// tensor c
if (input_2_ && input_2_->GetOutputDesc().flags == DML_TENSOR_FLAG_OWNED_BY_DML)
@@ -395,10 +396,10 @@ class Gemm : public DirectMlBaseNode
assert(resource_c != nullptr);
input_binds.push_back({ resource_c, 0, resource_c->GetDesc().Width });
}
- else
+ /* else
{
input_binds.push_back({ nullptr, 0, 0 });
- }
+ }*/
DML_BUFFER_ARRAY_BINDING input_bind{};
input_bind.BindingCount = static_cast(input_binds.size());
@@ -450,6 +451,8 @@ class GemmBaseDispatcher : public NodeDispatcher
bool b_managed = false;
bool b_transposed = false;
+ bool use_dpas = false;
+
inline static void add_cli_options(CLI::App* opts, create_params_t& params)
{
add_data_type_cli_option(opts, "--data_type", params.dt)->required();
@@ -461,7 +464,7 @@ class GemmBaseDispatcher : public NodeDispatcher
opts->add_option("--shape_c", params.shape_c);
opts->add_flag("--b_transposed", params.b_transposed)->default_val(false);
- opts->add_flag("--b_managed", params.b_managed)->default_val(false);;
+ opts->add_flag("--b_managed", params.b_managed)->default_val(false);
opts->add_option("--alpha", params.alpha);
opts->add_option("--beta", params.beta);
@@ -469,7 +472,7 @@ class GemmBaseDispatcher : public NodeDispatcher
opts->add_flag("--fuse_softmax", params.fuse_softmax)->default_val(false);
opts->add_option("--gemm_type", params.type, "Name of the type of GEMM to run.")
- ->check(CLI::IsMember({ GemmType::GemmType_AB, GemmType::GemmType_QK_QKV, GemmType::GemmType_SV_S_QKV, GemmType::GemmType_QK_Q_KV, GemmType::GemmType_SV_S_KV }))->
+ ->check(CLI::IsMember({ GemmType::GemmType_AB, GemmType::GemmType_QK_QKV, GemmType::GemmType_SV_S_QKV, GemmType::GemmType_QK_Q_KV, GemmType::GemmType_SV_S_KV}))->
transform(CLI::Transformer(std::map{
{ "ab", GemmType::GemmType_AB },
{ "qk_qkv", GemmType::GemmType_QK_QKV },
@@ -478,6 +481,8 @@ class GemmBaseDispatcher : public NodeDispatcher
{ "sv_s_kv", GemmType::GemmType_SV_S_KV },
}, CLI::ignore_case))->required();
+ opts-> add_flag("--use_dpas", params.use_dpas)->default_val(false);
+
}
};
public:
@@ -516,7 +521,7 @@ class GemmBaseDispatcher : public NodeDispatcher
assert(!input_data_a_.empty());
assert(!input_data_b_.empty());
}
- else if (params_.type == GemmType::GemmType_QK_Q_KV)
+ else if (params_.type == GemmType::GemmType_QK_Q_KV )
{
assert(params_.shape_a.get_dims_count() == 3); // q input
assert(params_.shape_b.get_dims_count() == 5); // q_kv input
@@ -756,7 +761,7 @@ class GemmBaseDispatcher : public NodeDispatcher
{
return params_.shape_a.h;
}
- else if (params_.type == GemmType::GemmType_QK_QKV || params_.type == GemmType::GemmType_QK_Q_KV)
+ else if (params_.type == GemmType::GemmType_QK_QKV || params_.type == GemmType::GemmType_QK_Q_KV)
{
return params_.shape_a.c;
}
@@ -788,7 +793,7 @@ class GemmBaseDispatcher : public NodeDispatcher
{
return params_.b_transposed ? params_.shape_b.h : params_.shape_b.w;
}
- else if (params_.type == GemmType::GemmType_QK_QKV)
+ else if (params_.type == GemmType::GemmType_QK_QKV )
{
return params_.shape_a.c;
}
@@ -886,6 +891,7 @@ class GemmCmDispatcher : public GemmBaseDispatcher
opts->add_option("--lws_z", params.lws[2]);
opts->add_option("--slice_k", params.slice_k);
+
}
};
@@ -942,7 +948,7 @@ class GemmCmDispatcher : public GemmBaseDispatcher
}
#endif
//cm_params_.lws[0] = 32;
- cm_params_.lws[1] = 16;
+ //cm_params_.lws[1] = 16;
}
else if(params_.type == GemmType::GemmType_QK_Q_KV)
{
@@ -1037,7 +1043,8 @@ class GemmCmDispatcher : public GemmBaseDispatcher
add_define("SIZE_HEAD_SIZE", params_.shape_b.w);
}
- add_define("SCALE", params_.alpha);
+ add_define("ALPHA", params_.alpha);
+ add_define("BETA", params_.beta);
add_define("DT", "half");
@@ -1063,13 +1070,22 @@ class GemmCmDispatcher : public GemmBaseDispatcher
std::cout << build_options_final << std::endl;
}
- auto kernel_source_content = [](GemmType type)
+ auto kernel_source_content = [](GemmType type, bool dpas_flag)
{
std::string path = "";
switch (type)
{
case GemmType::GemmType_AB: path = "gemm_nchw_fp16.cpp"; break;
- case GemmType::GemmType_QK_QKV: path = "mha_qk_qkv_gemm_fp16.cpp"; break;
+ case GemmType::GemmType_QK_QKV:
+ {
+ if(dpas_flag == true)
+ {
+ path = "mha_qk_qkv_gemm_dpas.cpp";
+ }else{
+ path = "mha_qk_qkv_gemm_fp16.cpp";
+ }
+ break;
+ }
case GemmType::GemmType_SV_S_QKV: path = "mha_sv_s_qkv_gemm_fp16.cpp"; break;
case GemmType::GemmType_SV_S_KV: path = "mha_sv_s_kv_gemm_fp16.cpp"; break;
case GemmType::GemmType_QK_Q_KV: path = "mha_qk_q_kv_gemm_fp16.cpp"; break;
@@ -1084,7 +1100,7 @@ class GemmCmDispatcher : public GemmBaseDispatcher
throw std::runtime_error(msg);
}
return std::string((std::istreambuf_iterator(file)), (std::istreambuf_iterator()));
- }(params_.type);
+ }(params_.type,params_.use_dpas);
CD3DX12_SHADER_BYTECODE byte_code;
byte_code.pShaderBytecode = kernel_source_content.data();
@@ -1161,9 +1177,19 @@ class GemmCmDispatcher : public GemmBaseDispatcher
}
else if (params_.type == GemmType::GemmType_QK_QKV)
{
- gws_x = get_N() / cm_params_.tile_n; // n first
- gws_y = get_M() / cm_params_.tile_m; // m second
- gws_z = get_batch() * get_channels() * cm_params_.slice_k;
+ if(params_.use_dpas)
+ {
+ gws_x = 8;
+ gws_y = 2;
+ gws_z = 16;
+
+ }
+ else
+ {
+ gws_x = get_N() / cm_params_.tile_n; // n first
+ gws_y = get_M() / cm_params_.tile_m; // m second
+ gws_z = get_batch() * get_channels() * cm_params_.slice_k;
+ }
}
else
{