@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
408408    cl_program program_mul_mv_id_mxfp4_f32_flat;
409409    cl_program program_mul_mm_f32_f32_l4_lm;
410410    cl_program program_mul_mm_f16_f32_l4_lm;
411+     cl_program program_mul_mm_q8_0_f32_l4_lm;
411412
412413    cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
413414    cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
480481    cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
481482    cl_kernel kernel_mul_mm_f32_f32_l4_lm;
482483    cl_kernel kernel_mul_mm_f16_f32_l4_lm;
484+     cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
483485
484486    std::vector<ProfilingInfo> profiling_info;
485487
@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
11911193        GGML_LOG_CONT (" ." 
11921194    }
11931195
1196+     //  mul_mm_q8_0_f32_l4_lm
1197+     {
1198+ #ifdef  GGML_OPENCL_EMBED_KERNELS
1199+         const  std::string kernel_src {
1200+             #include  " mul_mm_q8_0_f32_l4_lm.cl.h" 
1201+         };
1202+ #else 
1203+         const  std::string kernel_src = read_file (" mul_mm_q8_0_f32_l4_lm.cl" 
1204+ #endif 
1205+         backend_ctx->program_mul_mm_q8_0_f32_l4_lm  =
1206+             build_program_from_source (backend_ctx->context , backend_ctx->device , kernel_src.c_str (), compile_opts);
1207+ 
1208+         CL_CHECK ((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm  = clCreateKernel (backend_ctx->program_mul_mm_q8_0_f32_l4_lm , " kernel_mul_mm_q8_0_f32_l4_lm" 
1209+         GGML_LOG_CONT (" ." 
1210+     }
1211+ 
11941212    //  mul
11951213    {
11961214#ifdef  GGML_OPENCL_EMBED_KERNELS
@@ -6961,6 +6979,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
69616979                backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
69626980                return ;
69636981            }
6982+             case  GGML_TYPE_Q8_0: {
6983+                 if  (ne11 < 32 ) {
6984+                     break ;
6985+                 }
6986+                 kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm ;
6987+                 nth0 = 128 ; //  calculated as (BM*BN)/(TM*TN)
6988+ 
6989+                 int  batch_stride_a = ne00*ne01;
6990+                 int  batch_stride_b = ne10*ne11;
6991+                 int  batch_stride_d = ne0*ne1;
6992+ 
6993+                 CL_CHECK (clSetKernelArg (kernel,  0 , sizeof (cl_mem),   &extra0_q8_0->q ));
6994+                 CL_CHECK (clSetKernelArg (kernel,  1 , sizeof (cl_mem),   &extra0_q8_0->d ));
6995+                 CL_CHECK (clSetKernelArg (kernel,  2 , sizeof (cl_mem),   &extra1->data_device ));
6996+                 CL_CHECK (clSetKernelArg (kernel,  3 , sizeof (cl_ulong), &offset1));
6997+                 CL_CHECK (clSetKernelArg (kernel,  4 , sizeof (cl_mem),   &extrad->data_device ));
6998+                 CL_CHECK (clSetKernelArg (kernel,  5 , sizeof (cl_ulong), &offsetd));
6999+                 CL_CHECK (clSetKernelArg (kernel,  6 , sizeof (int ),      &ne00));
7000+                 CL_CHECK (clSetKernelArg (kernel,  7 , sizeof (int ),      &ne01));
7001+                 CL_CHECK (clSetKernelArg (kernel,  8 , sizeof (int ),      &ne02));
7002+                 CL_CHECK (clSetKernelArg (kernel,  9 , sizeof (int ),      &ne11));
7003+                 CL_CHECK (clSetKernelArg (kernel, 10 , sizeof (int ),      &ne12));
7004+                 CL_CHECK (clSetKernelArg (kernel, 11 , sizeof (int ),      &ne10)); //  stride_a
7005+                 CL_CHECK (clSetKernelArg (kernel, 12 , sizeof (int ),      &ne10)); //  stride_b
7006+                 CL_CHECK (clSetKernelArg (kernel, 13 , sizeof (int ),      &ne01)); //  stride_d
7007+                 CL_CHECK (clSetKernelArg (kernel, 14 , sizeof (int ),      &batch_stride_a));
7008+                 CL_CHECK (clSetKernelArg (kernel, 15 , sizeof (int ),      &batch_stride_b));
7009+                 CL_CHECK (clSetKernelArg (kernel, 16 , sizeof (int ),      &batch_stride_d));
7010+                 CL_CHECK (clSetKernelArg (kernel, 17 , sizeof (int ),      &r2));
7011+                 CL_CHECK (clSetKernelArg (kernel, 18 , sizeof (int ),      &r3));
7012+ 
7013+                 //  64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
7014+                 size_t  global_work_size[] = {(size_t )(CEIL_DIV (ne01, 64 )*nth0), (size_t )(CEIL_DIV (ne11, 64 )), (size_t )ne12*ne13};
7015+                 size_t  local_work_size[] = {(size_t )nth0, 1 , 1 };
7016+ 
7017+                 backend_ctx->enqueue_ndrange_kernel (kernel, 3 , global_work_size, local_work_size, dst);
7018+                 return ;
7019+             }
69647020            default :
69657021                break ;
69667022        }
0 commit comments