Skip to content

Commit 8138785

Browse files
authored
opencl: transposed gemm/gemv moe kernel with mxfp4,f32 (ggml-org#16602)
* opencl: transposed gemm/gemv moe kernel with mxfp4,f32 * add restore kernel for moe transpose * fix trailing whitespaces * resolve compilation warnings
1 parent 66b0dbc commit 8138785

File tree

5 files changed

+567
-8
lines changed

5 files changed

+567
-8
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ set(GGML_OPENCL_KERNELS
9191
mul_mv_id_q8_0_f32_flat
9292
mul_mv_id_mxfp4_f32
9393
mul_mv_id_mxfp4_f32_flat
94+
gemm_moe_mxfp4_f32
95+
gemv_moe_mxfp4_f32
9496
mul_mm_f32_f32_l4_lm
9597
mul_mm_f16_f32_l4_lm
9698
mul_mm_q8_0_f32_l4_lm

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 205 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ struct ggml_backend_opencl_context {
402402
cl_program program_conv_2d_f32;
403403
cl_program program_conv_2d_f16_f32;
404404
cl_program program_tsembd;
405+
cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
405406
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
406407
cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
407408
cl_program program_mul_mv_id_mxfp4_f32;
@@ -452,7 +453,7 @@ struct ggml_backend_opencl_context {
452453
cl_kernel kernel_mul_mat_f16_f32_tiled;
453454
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
454455
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
455-
cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
456+
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
456457
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
457458
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
458459
cl_kernel kernel_convert_block_q4_0_noshuffle;
@@ -475,6 +476,7 @@ struct ggml_backend_opencl_context {
475476
cl_kernel kernel_conv_2d_f32;
476477
cl_kernel kernel_conv_2d_f16_f32;
477478
cl_kernel kernel_timestep_embedding;
479+
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
478480
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
479481
cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
480482
cl_kernel kernel_mul_mv_id_mxfp4_f32;
@@ -559,14 +561,14 @@ struct ggml_backend_opencl_context {
559561

560562
fprintf(ftrace, "[\n");
561563
for (const ProfilingInfo & info : profiling_info) {
562-
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
564+
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
563565
info.kernel_name.c_str(), info.cmd_queued/1000);
564-
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
566+
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
565567
info.kernel_name.c_str(), info.cmd_submit/1000);
566568

567-
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
569+
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
568570
info.kernel_name.c_str(), info.cmd_start/1000);
569-
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
571+
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
570572
info.kernel_name.c_str(), info.cmd_end/1000);
571573
}
572574
fclose(ftrace);
@@ -777,6 +779,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
777779
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
778780
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
779781
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
782+
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
783+
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
780784
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
781785
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
782786
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -1991,6 +1995,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
19911995
CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
19921996
GGML_LOG_CONT(".");
19931997
}
1998+
1999+
std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
2000+
" -cl-mad-enable "
2001+
" -cl-fast-relaxed-math";
2002+
2003+
// gemv_moe_mxfp4_f32
2004+
{
2005+
#ifdef GGML_OPENCL_EMBED_KERNELS
2006+
const std::string kernel_src {
2007+
#include "gemv_moe_mxfp4_f32.cl.h"
2008+
};
2009+
#else
2010+
const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
2011+
#endif
2012+
backend_ctx->program_gemv_moe_mxfp4_f32 =
2013+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
2014+
2015+
CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
2016+
GGML_LOG_CONT(".");
2017+
}
2018+
2019+
// gemm_moe_mxfp4_f32
2020+
{
2021+
#ifdef GGML_OPENCL_EMBED_KERNELS
2022+
const std::string kernel_src {
2023+
#include "gemm_moe_mxfp4_f32.cl.h"
2024+
};
2025+
#else
2026+
const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
2027+
#endif
2028+
backend_ctx->program_gemm_moe_mxfp4_f32 =
2029+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
2030+
2031+
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
2032+
GGML_LOG_CONT(".");
2033+
}
19942034
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
19952035
GGML_LOG_CONT("\n");
19962036
}
@@ -3299,6 +3339,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
32993339
tensor->ne[2] == 1 && tensor->ne[3] == 1;
33003340
}
33013341

3342+
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
3343+
GGML_UNUSED(backend_ctx);
3344+
int ne01 = tensor->ne[1];
3345+
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
3346+
}
3347+
33023348
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
33033349
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
33043350

@@ -3601,14 +3647,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
36013647
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
36023648
CL_CHECK(err);
36033649

3650+
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
3651+
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
3652+
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
3653+
3654+
int ne00 = tensor->ne[0];
3655+
int ne01 = tensor->ne[1];
3656+
int ne02 = tensor->ne[2];
3657+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
3658+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
3659+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
3660+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
3661+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
3662+
3663+
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
3664+
size_t local_work_size[3] = {64, 2, 1};
3665+
3666+
cl_event evt;
3667+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
3668+
CL_CHECK(clWaitForEvents(1, &evt));
3669+
CL_CHECK(clReleaseMemObject(data_device));
3670+
tensor->extra = extra;
3671+
3672+
return;
3673+
}
3674+
#endif
36043675
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
36053676

36063677
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
36073678
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
36083679
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
36093680

3610-
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
3611-
size_t local_work_size[] = {64, 1, 1};
3681+
size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
3682+
size_t local_work_size[3] = {64, 1, 1};
36123683

36133684
cl_event evt;
36143685
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
@@ -3624,7 +3695,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
36243695
{ extra->q }
36253696
};
36263697
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
3627-
36283698
tensor->extra = extra;
36293699

36303700
return;
@@ -3751,6 +3821,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
37513821
ggml_nbytes(tensor), NULL, &err);
37523822
CL_CHECK(err);
37533823

3824+
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
3825+
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
3826+
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
3827+
3828+
int ne00 = tensor->ne[0];
3829+
int ne01 = tensor->ne[1];
3830+
int ne02 = tensor->ne[2];
3831+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
3832+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
3833+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
3834+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
3835+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
3836+
3837+
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
3838+
size_t local_work_size[3] = {64, 2, 1};
3839+
3840+
cl_event evt;
3841+
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
3842+
global_work_size, local_work_size, 0, NULL, &evt));
3843+
CL_CHECK(clWaitForEvents(1, &evt));
3844+
CL_CHECK(clEnqueueReadBuffer(
3845+
queue, data_device, CL_TRUE, offset,
3846+
size, data, 0, NULL, NULL));
3847+
CL_CHECK(clReleaseMemObject(data_device));
3848+
return;
3849+
}
3850+
#endif
37543851
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
37553852
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
37563853
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -7553,6 +7650,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
75537650
const int ne21 = src2->ne[1];
75547651

75557652
const cl_ulong nb21 = src2->nb[1];
7653+
const cl_ulong nb20 = src2->nb[0];
75567654

75577655
const int ne0 = dst->ne[0];
75587656
const int ne1 = dst->ne[1];
@@ -7692,6 +7790,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
76927790
break;
76937791
}
76947792
case GGML_TYPE_MXFP4: {
7793+
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
7794+
if (use_adreno_moe_kernels(backend_ctx, src0)) {
7795+
cl_int status;
7796+
7797+
size_t local_size[3] = {64, 2, 1};
7798+
size_t global_size[3] = {64, 2, 1};
7799+
7800+
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
7801+
7802+
int tile_size = 320;
7803+
if (ne12 == 1) { // for gemv
7804+
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
7805+
7806+
// create a sub_buffer for src2
7807+
cl_buffer_region region;
7808+
region.origin = offset2;
7809+
region.size = ne20 * ne21 * sizeof(int);
7810+
buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
7811+
CL_CHECK(status);
7812+
7813+
// set thread grid
7814+
global_size[0] = static_cast<size_t>(ne01);
7815+
global_size[1] = 4;
7816+
global_size[2] = static_cast<size_t>(ne20);
7817+
local_size[1] = 4;
7818+
} else { // for gemm
7819+
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
7820+
7821+
// preprocess router table
7822+
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
7823+
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
7824+
void * host_src2 = malloc(ne21 * nb21);
7825+
CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
7826+
int total_experts = nb21 / nb20;
7827+
int out_idx = 0;
7828+
for (int i_expert = 0; i_expert < ne02; i_expert++) {
7829+
for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
7830+
for (int j = 0; j < ne21; j++) {
7831+
for (int i = 0; i < ne20; i++) {
7832+
int expert = ((int *)host_src2)[j * total_experts + i];
7833+
if (i_expert == expert) {
7834+
((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
7835+
((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
7836+
((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
7837+
((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
7838+
out_idx += 4;
7839+
}
7840+
}
7841+
}
7842+
}
7843+
}
7844+
buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
7845+
CL_CHECK(status);
7846+
7847+
// set thread grid
7848+
global_size[0] = static_cast<size_t>(tile_size);
7849+
global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
7850+
}
7851+
7852+
// create a sub_buffer for src1
7853+
cl_buffer_region region;
7854+
region.origin = offset1;
7855+
region.size = ne10 * ne11 * ne12 * sizeof(float);
7856+
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
7857+
CL_CHECK(status);
7858+
7859+
// create image for src1
7860+
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
7861+
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
7862+
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
7863+
CL_CHECK(status);
7864+
7865+
// Set kernel args
7866+
int arg_idx = 0;
7867+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
7868+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
7869+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
7870+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
7871+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
7872+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
7873+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
7874+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
7875+
if (ne12 == 1) {
7876+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
7877+
} else {
7878+
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
7879+
}
7880+
7881+
// launch kernel
7882+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
7883+
7884+
// deallocate sub buffers and images
7885+
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
7886+
CL_CHECK(clReleaseMemObject(buf_src1_image));
7887+
CL_CHECK(clReleaseMemObject(buf_src2));
7888+
return;
7889+
} // else fallback to generic kernel
7890+
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
7891+
76957892
#ifdef GGML_OPENCL_SOA_Q
76967893
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
76977894

ggml/src/ggml-opencl/kernels/cvt.cl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,27 @@ kernel void kernel_convert_block_mxfp4(
147147
}
148148
}
149149

150+
kernel void kernel_convert_block_mxfp4_trans(
151+
global struct block_mxfp4 * src0,
152+
__global uint4 * dst_q,
153+
__global uchar * dst_e,
154+
uint ne00,
155+
uint ne01
156+
) {
157+
int i00 = get_global_id(1);
158+
uint i01 = get_global_id(0);
159+
uint i02 = get_global_id(2);
160+
161+
uint ne00_blk = ne00 / QK_MXFP4;
162+
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
163+
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
164+
165+
global struct block_mxfp4 * b = src0 + src_blk_offset;
166+
167+
dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];
168+
dst_e[dst_blk_offset] = b->e;
169+
}
170+
150171
kernel void kernel_restore_block_mxfp4(
151172
global uchar * src_q,
152173
global half * src_e,
@@ -162,6 +183,27 @@ kernel void kernel_restore_block_mxfp4(
162183
}
163184
}
164185

186+
kernel void kernel_restore_block_mxfp4_trans(
187+
__global uint4 * src_q,
188+
__global uchar * src_e,
189+
global struct block_mxfp4 * dst,
190+
uint ne00,
191+
uint ne01
192+
) {
193+
int i00 = get_global_id(1);
194+
uint i01 = get_global_id(0);
195+
uint i02 = get_global_id(2);
196+
197+
uint ne00_blk = ne00 / QK_MXFP4;
198+
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
199+
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
200+
201+
global struct block_mxfp4 * b = dst + dst_blk_offset;
202+
203+
((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
204+
b->e = src_e[src_blk_offset];
205+
}
206+
165207
//------------------------------------------------------------------------------
166208
// block_q8_0
167209
//------------------------------------------------------------------------------

0 commit comments

Comments
 (0)