Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ReadMe.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Test Case
|Engine| Type|Command line|
|----|---|---|
|DML|MHA|`cross_runner.exe --iters 1 --type mha_dml mha_opts --data_type fp16 --layout nchw --mha_type qkv --shape_input 2,64,8,3,160`|
|DML|MHA|`cross_runner.exe --iters 1 --type mha_dml mha_opts --data_type fp16 --layout nchw --mha_type qkv --shape_input 2,64,8,3,160`|
|CM|GEMM_QK_QKV|`cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type qk_qkv --data_type fp16 --layout nchw --shape_a 2,64,8,3,160 gemm_cm_opts --large_grf --tile_m 16 --tile_k 80 --tile_n 64 --lws_x 1 --lws_y 1 --lws_z 2 --slice_k 2 --dump_asm`|
|CM|GEMM_QK_QKV dpas|`cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type qk_qkv --use_dpas=1 --data_type fp16 --layout nchw --shape_a 2,64,8,3,160 gemm_cm_opts --large_grf --tile_m 16 --tile_k 80 --tile_n 64 --lws_x 8 --lws_y 2 --lws_z 1 --dump_asm`|
|CM|GEMM_AB| `cross_runner.exe --iters 100 --no_conform=0 --type gemm_cm gemm_opts --gemm_type ab --data_type fp16 --layout nchw --shape_a 1,1,8192,320 --shape_b 1,1,320,320 --b_managed gemm_cm_opts --large_grf --tile_m 16 --tile_k 16 --tile_n 64 --dump_asm`|
|DML|GEMM_AB| `cross_runner.exe --iters 100 --type gemm_dml gemm_opts --data_type fp16 --layout nchw --gemm_type ab --shape_a 1,1,8192,320 --shape_b 1,1,320,320 --b_managed`|
3 changes: 2 additions & 1 deletion tools/cross_runner/kernels/gemm_nchw_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ extern "C" _GENX_MAIN_ void gemm_nchw_fp16(


matrix<DT, TILE_M, TILE_N> accu_out = accu; // if DT_ACCU == DT then compiler removes this line
accu_out *= DT(SCALE);
accu_out *= DT(ALPHA);
accu_out += DT(BETA);

#pragma unroll
for(uint32_t i = 0; i < TILE_M; i++)
Expand Down
170 changes: 170 additions & 0 deletions tools/cross_runner/kernels/mha_qk_qkv_gemm_dpas.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include <cm/cm.h>
#include <cm/cmtl.h>

#if !defined(EMPTY)

#define SIZE_OF_HF16_BYTE 2

#define WG_TILE_M 64
#define WG_TILE_N 64

#define SG_TILE_M 8
#define SG_TILE_N 32

#define SG_TILE_NUM_ROWS 8
#define HALF half
#define FLOAT float
#define DIM_X 0
#define DIM_Y 1
#define DIM_Z 2
#define BASE_OUTPUT_OFFSET 0

_GENX_ inline void myDPAS8(matrix_ref<HALF, 8, 16> matA,
matrix_ref<HALF, 8, 16> matB,
matrix_ref<FLOAT, 8, 8> result)
{
result = cm_dpas<CM_PRECISION_HF, CM_PRECISION_HF, 8, 8>(result.format<FLOAT>(), matB.format<U32>(), matA.format<U32>());
}

#endif

extern "C" _GENX_MAIN_ void
mha_qk_qkv_gemm_dpas(SurfaceIndex INMTXa[[type("buffer_t half")]], // 0 input qkv surface
SurfaceIndex OUTMTX[[type("buffer_t half")]] // 1 output qxk surface
) {
#if !defined(EMPTY)

//A matrix format: [K/16][M][16k]
//A tile: 32Mx16K
vector<HALF, 128> readA1;//M=0..7,K=0..15
matrix_ref<HALF, 8, 16> readA1_m = readA1.format<HALF, 8, 16>();

//B matrix format: [K/16][N/8][8K][8N][2K]
//B tile: 32Nx16K
matrix<HALF, 8, 16> readB1;//N=0..7,K=0..15
matrix<HALF, 8, 16> readB2;//N=8..15,K=0..15
matrix<HALF, 8, 16> readB3;//N=16..23,K=0..15
matrix<HALF, 8, 16> readB4;//N=24..32,K=0..15

matrix_ref<HALF, 8, 16> readB1_m = readB1.format<HALF, 8, 16>();
matrix_ref<HALF, 8, 16> readB2_m = readB2.format<HALF, 8, 16>();
matrix_ref<HALF, 8, 16> readB3_m = readB3.format<HALF, 8, 16>();
matrix_ref<HALF, 8, 16> readB4_m = readB4.format<HALF, 8, 16>();

matrix<FLOAT, 8, 8> result11;
matrix<FLOAT, 8, 8> result12;
matrix<FLOAT, 8, 8> result13;
matrix<FLOAT, 8, 8> result14;

matrix_ref<FLOAT, 8, 8> result11ref = result11;
matrix_ref<FLOAT, 8, 8> result12ref = result12;
matrix_ref<FLOAT, 8, 8> result13ref = result13;
matrix_ref<FLOAT, 8, 8> result14ref = result14;

uint gidY = cm_group_id(DIM_Y);
uint gidX = cm_group_id(DIM_X);
uint gidZ = cm_group_id(DIM_Z);
uint tidX = cm_local_id(DIM_X);
uint tidY = cm_local_id(DIM_Y);
uint tidZ = cm_local_id(DIM_Z);

//input surface control variables
const unsigned input_head_count_stride_qkv = SIZE_HEAD_SIZE * 3;
const unsigned input_sequence_stride_qkv = SIZE_NUM_HEADS * input_head_count_stride_qkv;
const unsigned input_batch_stride_qkv = SIZE_SEQ_LEN * input_sequence_stride_qkv;
const unsigned input_surface_tile_base_offset_q = tidX * SG_TILE_M + gidX * WG_TILE_M;
const unsigned input_surface_tile_base_offset_k = tidY * SG_TILE_N + gidY * WG_TILE_N;
const unsigned input_cacheline_stride = input_sequence_stride_qkv * SIZE_OF_HF16_BYTE;
const unsigned input_cacheline_stride_rows = SG_TILE_NUM_ROWS * input_cacheline_stride;

//output surface control variables
const unsigned output_head_count_stride = SIZE_SEQ_LEN * SIZE_SEQ_LEN;
const unsigned output_batch_stride = SIZE_NUM_HEADS * output_head_count_stride;
const unsigned output_tile_base_offset = ((tidX * SG_TILE_M + gidX * WG_TILE_M) * SIZE_SEQ_LEN) + (tidY * SG_TILE_N + gidY * WG_TILE_N);
//const int sg_tile_base_offset = tidX * SG_TILE_M * SIZE_SEQ_LEN + tidY * SG_TILE_N;

const unsigned batch_count = gidZ/SIZE_NUM_HEADS;
const unsigned head_count = gidZ % SIZE_NUM_HEADS;
const unsigned input_surface_batch_start_offset = input_batch_stride_qkv * batch_count;
const unsigned output_surface_batch_start_offset = output_batch_stride * batch_count;

const unsigned head_start_offset = output_head_count_stride * head_count;
const unsigned output_surface_tile_start_offset = BASE_OUTPUT_OFFSET + output_surface_batch_start_offset + head_start_offset + output_tile_base_offset; // + sg_tile_base_offset;
const unsigned input_surface_tile_start_offset_q = head_count * input_head_count_stride_qkv + batch_count * input_surface_batch_start_offset;
const unsigned input_surface_tile_start_offset_k = head_count * input_head_count_stride_qkv + batch_count * input_surface_batch_start_offset + SIZE_HEAD_SIZE;

//init the accumulators
result11.select_all() = 0;
result12.select_all() = 0;
result13.select_all() = 0;
result14.select_all() = 0;

for (int head_size_step = 0; head_size_step < SIZE_HEAD_SIZE; head_size_step += 16) //iterates to process the entire K for A and B.
{
const bool is_head_size_step_multiple_16 = (head_size_step + 16 <= SIZE_HEAD_SIZE) ? true : false;
const unsigned input_surface_tile_start_q = (input_surface_tile_start_offset_q + head_size_step + input_surface_tile_base_offset_q * input_sequence_stride_qkv) * SIZE_OF_HF16_BYTE;;
const unsigned input_surface_tile_start_k = (input_surface_tile_start_offset_k + head_size_step + input_surface_tile_base_offset_k * input_sequence_stride_qkv) * SIZE_OF_HF16_BYTE;;
readA1.select_all() = 0.0;
readB1.select_all() = 0.0;
readB2.select_all() = 0.0;
readB3.select_all() = 0.0;
readB4.select_all() = 0.0;

#pragma unroll
for (int row = 0; row < SG_TILE_NUM_ROWS; row++)
{
const unsigned row_offset_in_bytes = row * SG_TILE_NUM_ROWS * SIZE_OF_HF16_BYTE;
const unsigned rowX2 = row * 2;
const unsigned input_surface_strride = row * input_cacheline_stride;
const unsigned input_surface_q_CL1 = input_surface_tile_start_q + input_surface_strride;
const unsigned input_surface_k_CL1 = input_surface_tile_start_k + input_surface_strride;
const unsigned input_surface_k_CL2 = input_surface_k_CL1 + input_cacheline_stride_rows;
const unsigned input_surface_k_CL3 = input_surface_k_CL2 + input_cacheline_stride_rows;
const unsigned input_surface_k_CL4 = input_surface_k_CL3 + input_cacheline_stride_rows;

// 8M x 16K (one complete row per iteration)
if(is_head_size_step_multiple_16)
{
readA1.select<16,1>(row_offset_in_bytes).format<U32>() = cm_load<U32, 8, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_q_CL1);
readB1.select<8,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 8, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL1);
readB2.select<8,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 8, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL2);
readB3.select<8,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 8, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL3);
readB4.select<8,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 8, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL4);
}
else
{
readA1.select<8,1>(row_offset_in_bytes).format<U32>() = cm_load<U32, 4, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_q_CL1);
readB1.select<4,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 4, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL1);
readB2.select<4,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 4, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL2);
readB3.select<4,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 4, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL3);
readB4.select<4,1,2,1>(0, rowX2).format<U32>() = cm_load<U32, 4, DataSize::Default, CacheHint::Cached, CacheHint::Cached>(INMTXa, input_surface_k_CL4);
}
}

myDPAS8(readA1_m, readB1_m, result11ref);
myDPAS8(readA1_m, readB2_m, result12ref);
myDPAS8(readA1_m, readB3_m, result13ref);
myDPAS8(readA1_m, readB4_m, result14ref);
}

vector<HALF, 32> result_hf16_CL1 = 0.0;
result11 *= HALF(ALPHA);
result12 *= HALF(ALPHA);
result13 *= HALF(ALPHA);
result14 *= HALF(ALPHA);

#pragma unroll
for(int j = 0; j < SG_TILE_NUM_ROWS; j++)
{
const unsigned write_index_base = j * SIZE_SEQ_LEN;
const unsigned write_index_0 = (output_surface_tile_start_offset + write_index_base) * SIZE_OF_HF16_BYTE;

result_hf16_CL1.select<8, 1>(0) = result11ref.select<1, 1, 8, 1>(j, 0);
result_hf16_CL1.select<8, 1>(8) = result12ref.select<1, 1, 8, 1>(j, 0);
result_hf16_CL1.select<8, 1>(16) = result13ref.select<1, 1, 8, 1>(j, 0);
result_hf16_CL1.select<8, 1>(24) = result14ref.select<1, 1, 8, 1>(j, 0);

cm_store<U32, 16, DataSize::Default, CacheHint::WriteBack, CacheHint::WriteBack>(OUTMTX, write_index_0, result_hf16_CL1.format<U32>());
}
#endif // !defined(EMPTY)
}
36 changes: 28 additions & 8 deletions tools/cross_runner/kernels/mha_qk_qkv_gemm_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#else
#define DT_ACCU DT
#endif
#define BASE_OUTPUT_OFFSET 0

#define LOAD_SIMD_SIZE 16
#define K_PER_LOAD 8
Expand Down Expand Up @@ -56,9 +57,9 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
)
{

const uint32_t thread_id_0 = cm_group_id(0) * cm_local_size(0) + cm_local_id(0);
const uint32_t thread_id_1 = cm_group_id(1) * cm_local_size(1) + cm_local_id(1);
const uint32_t thread_id_2 = cm_group_id(2) * cm_local_size(2) + cm_local_id(2);
const uint32_t thread_id_0 = cm_group_id(0) * LWS_SIZE_X + cm_local_id(0);
const uint32_t thread_id_1 = cm_group_id(1) * LWS_SIZE_Y + cm_local_id(1);
const uint32_t thread_id_2 = cm_group_id(2) * LWS_SIZE_Z + cm_local_id(2);
const uint32_t k_slice_thread_offset = cm_local_id(2);

const uint32_t batch_thread_offset = cm_group_id(2) / SIZE_NUM_HEADS;
Expand All @@ -82,6 +83,7 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(

#if SLM_KN_SHARING
cm_slm_init(TILE_K * TILE_N * sizeof(DT));
uint slm_buffer = cm_slm_alloc(TILE_K * TILE_N * sizeof(DT));
const uint32_t th_local_id = cm_local_id(1);
#endif

Expand Down Expand Up @@ -154,7 +156,13 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
#pragma unroll
for(uint32_t j = 0; j < TILE_M; j++)
{
#if ACCU_IS_FP32
vector<DT_ACCU, TILE_N> input_b_fp32 = vector<DT_ACCU, TILE_N>(input_b);
vector<DT_ACCU, TILE_N> input_a_fp32 = vector<DT_ACCU, TILE_N>(input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate<TILE_N>());
accu.select<1, 1, TILE_N, 1>(j, 0) += input_b_fp32 * input_a_fp32;
#else
accu.select<1, 1, TILE_N, 1>(j, 0) += input_b * input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate<TILE_N>();
#endif
}
#endif
}
Expand All @@ -175,15 +183,25 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(
#pragma unroll
for(uint32_t m = 0; m < TILE_M; m++)
{

#if 0
vector<DT_ACCU, TILE_N> input_b_fp32 = vector<DT_ACCU, TILE_N>(input_b);
vector<DT_ACCU, TILE_N> input_a_fp32 = vector<DT_ACCU, TILE_N>(input.select<1, 1, 1, 1>(j, k_chunk * ks + k * packed_eles + i).replicate<TILE_N>());
accu.select<1, 1, TILE_N, 1>(j, 0) += input_b_fp32 * input_a_fp32;
#else
accu.select<1, 1, TILE_N, 1>(m, 0) += input_b * input.select<1, 1, 1, 1>(m, k).replicate<TILE_N>();
#endif


}
}
#endif

#if SLICE_K > 1
const uint32_t TILE_N_PACKED = TILE_N / (sizeof(uint32_t)/sizeof(DT_ACCU));
cm_slm_init(TILE_M * TILE_N * sizeof(DT_ACCU) * (LWS_SIZE_Z - 1));

uint slm_buffer = cm_slm_alloc(TILE_M * TILE_N * sizeof(DT_ACCU) * (LWS_SIZE_Z - 1));

if(cm_local_id(2) > 0)
{
#pragma unroll
Expand Down Expand Up @@ -220,13 +238,15 @@ extern "C" _GENX_MAIN_ void mha_qk_qkv_gemm(

const uint32_t output_store_size = (TILE_N * sizeof(DT)) / sizeof(uint32_t);
uint32_t output_offset =
(batch_thread_offset * SIZE_NUM_HEADS + head_thread_offset) * SIZE_M * SIZE_N * sizeof(DT)
+ thread_id_1 * TILE_M * SIZE_N * sizeof(DT)
+ (thread_id_0 * TILE_N * sizeof(DT));
(batch_thread_offset * SIZE_NUM_HEADS * SIZE_M * SIZE_N
+ head_thread_offset * SIZE_M * SIZE_N
+ thread_id_1 * TILE_M * SIZE_N
+ thread_id_0 * TILE_N
+ BASE_OUTPUT_OFFSET) * sizeof(DT);


matrix<DT, TILE_M, TILE_N> accu_out = accu; // if DT_ACCU == DT then compiler removes this line
accu_out *= DT(SCALE);
accu_out *= DT(ALPHA);
for(uint32_t i = 0; i < TILE_M; i++)
{
vector_ref<uint32_t, output_store_size> accu_0_packed = accu_out.select<1, 1, TILE_N, 1>(i, 0).format<uint32_t>();
Expand Down
Loading