@@ -321,6 +321,7 @@ struct ggml_backend_opencl_context {
     cl_program program_upscale;
     cl_program program_concat;
     cl_program program_tsembd;
+    cl_program program_mul_mv_id_q4_0_f32_8x_flat;

     cl_kernel kernel_add, kernel_add_row;
     cl_kernel kernel_mul, kernel_mul_row;
@@ -366,6 +367,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
     cl_kernel kernel_timestep_embedding;
+    cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;

 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // Transpose kernels
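A note on the `_flat` suffix in the new kernel name: under `GGML_OPENCL_SOA_Q` the backend stores Q4_0 weights not as interleaved `block_q4_0` structs but as two flat device buffers, which is what the kernel-arg setup further down binds as `extra0_q4_0->q` and `extra0_q4_0->d`. A minimal sketch of that layout (types abbreviated; `ggml_half` shown as `uint16_t`, matching ggml's fp16 storage type):

```cpp
#include <cstdint>

typedef uint16_t ggml_half; // fp16 storage type, as in ggml
#define QK4_0 32

// Standard ggml Q4_0 block: one fp16 scale plus 32 quants packed as nibbles.
typedef struct {
    ggml_half d;            // per-block scale
    uint8_t   qs[QK4_0/2];  // 16 bytes of packed 4-bit quants
} block_q4_0;

// Under GGML_OPENCL_SOA_Q these blocks are split ("flattened") into two
// separate device buffers, all scales in one and all quant nibbles in the
// other, so the kernel can do plain vectorized loads from each.
```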
@@ -1112,7 +1114,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }

-        // repeat
+    // repeat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
         const std::string kernel_src {
@@ -1256,6 +1258,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         }
     }

+    // mul_mv_id_q4_0_f32_8x_flat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_id_q4_0_f32_8x_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_id_q4_0_f32_8x_flat.cl");
+#endif
+        backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat, "kernel_mul_mv_id_q4_0_f32_8x_flat", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // Adreno kernels
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // transpose
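The loader block follows the backend's existing embed-vs-file pattern: the kernel source is either baked in via the generated `.cl.h` header or read from disk at startup. If the new kernel fails to compile on some driver, the standard OpenCL way to inspect the compiler output is `clGetProgramBuildInfo`. A sketch for ad-hoc debugging only; `print_build_log` is a hypothetical helper, and the backend's own `build_program_from_source` already does its error reporting:

```cpp
#include <cstdio>
#include <vector>
#include <CL/cl.h>

// Print the build log of a program that failed to compile (not part of the patch).
static void print_build_log(cl_program program, cl_device_id device) {
    size_t log_size = 0;
    clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
    std::vector<char> log(log_size + 1, '\0');
    clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, log_size, log.data(), NULL);
    fprintf(stderr, "OpenCL build log:\n%s\n", log.data());
}
```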
@@ -2178,6 +2196,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
             }
             return false;
+        case GGML_OP_MUL_MAT_ID:
+            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+                if (op->src[1]->type == GGML_TYPE_F32) {
+                    return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
+                }
+            }
+            return false;
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
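For context, `GGML_OP_MUL_MAT_ID` is the mixture-of-experts matmul: `src0` stacks the expert weight matrices along dimension 2, `src1` holds the activation rows, and `dst->src[2]` is an i32 tensor of expert indices selecting which matrix multiplies each row. The gate above restricts the OpenCL path to contiguous Q4_0 x F32; everything else falls back. A scalar f32 reference of the semantics as I read them (illustrative only; the flat-f32 layout and names are assumptions, and the real op also handles broadcast dims and quantized `src0`):

```cpp
#include <cstddef>

// Scalar reference for MUL_MAT_ID with f32 weights.
static void mul_mat_id_ref(const float * as,   // [ne00, ne01, n_expert]  src0
                           const float * b,    // [ne00, n_used, n_tok]   src1
                           const int   * ids,  // [n_used, n_tok]         src2
                           float       * dst,  // [ne01, n_used, n_tok]
                           int ne00, int ne01, int n_used, int n_tok) {
    for (int t = 0; t < n_tok; ++t) {
        for (int e = 0; e < n_used; ++e) {
            const int expert = ids[(size_t)t*n_used + e];  // expert picked for this row
            const float * A = as  + (size_t)expert*ne01*ne00;
            const float * x = b   + ((size_t)t*n_used + e)*ne00;
            float       * y = dst + ((size_t)t*n_used + e)*ne01;
            for (int r = 0; r < ne01; ++r) {               // one dot product per dst row
                float sum = 0.0f;
                for (int k = 0; k < ne00; ++k) {
                    sum += A[(size_t)r*ne00 + k]*x[k];
                }
                y[r] = sum;
            }
        }
    }
}
```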
@@ -5536,6 +5561,136 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
     }
 }

+static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const ggml_tensor * src2 = dst->src[2]; // ids: i32 expert indices per row
+    GGML_ASSERT(src2);
+    GGML_ASSERT(src2->extra);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+#ifdef GGML_OPENCL_SOA_Q
+    ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
+#endif
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb02 = src0->nb[2];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+
+    const int ne20 = src2->ne[0];
+    const int ne21 = src2->ne[1];
+
+    const cl_ulong nb21 = src2->nb[1];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+
+    const int r2 = ne12/ne02;
+    const int r3 = ne13/ne03;
+    const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows
+
+    GGML_ASSERT(ne00 == ne10);
+
+    int sgs   = 32; // subgroup size
+    int nsg   = 1;  // number of subgroups
+    int nrows = 1;  // number of rows in src1
+    int ndst  = 4;  // number of values produced by each subgroup
+
+    cl_kernel kernel;
+
+    // subgroup mat vec
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0: {
+            kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                sgs  = 16;
+                nsg  = 1;
+                ndst = 8;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                sgs  = 64;
+                nsg  = 1;
+                ndst = 8;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne20));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne21));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb21));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &r3));
+
+            break;
+        }
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    int _ne1 = 1;
+    int ne123 = dst_rows;
+
+    size_t global_work_size[] = {(size_t)(ne01+ndst*nsg-1)/(ndst*nsg)*sgs, (size_t)(_ne1+nrows-1)/nrows*nsg, (size_t)ne123};
+    size_t local_work_size[] = {(size_t)sgs, (size_t)nsg, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
 static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
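On the launch geometry: dimension 0 covers the rows of one expert matrix (each subgroup of `sgs` threads produces `ndst` output values), dimension 1 collapses to `nsg` since `_ne1 == 1`, and dimension 2 enumerates the `ne20*ne21` destination rows, one per (token, used-expert) pair. A worked example with the Intel parameters from the switch above and hypothetical shapes:

```cpp
#include <cstdio>

int main() {
    // sgs/nsg/nrows/ndst as set on the Intel path; shapes are illustrative.
    const int sgs = 16, nsg = 1, nrows = 1, ndst = 8;
    const int ne01 = 4096;           // rows per expert matrix (hypothetical)
    const int ne20 = 4, ne21 = 32;   // 4 experts used per token, 32 tokens
    const int dst_rows = ne20*ne21;  // 128: one grid slice per dst row

    // dim 0: ceil(ne01/(ndst*nsg)) workgroups of sgs threads; each subgroup
    // accumulates ndst output rows of the selected expert matrix.
    size_t gws0 = (size_t)(ne01 + ndst*nsg - 1)/(ndst*nsg)*sgs;  // 512*16 = 8192
    size_t gws1 = (size_t)(1 + nrows - 1)/nrows*nsg;             // 1 (_ne1 == 1)
    size_t gws2 = (size_t)dst_rows;                              // 128

    printf("global = {%zu, %zu, %zu}, local = {%d, %d, 1}\n",
           gws0, gws1, gws2, sgs, nsg);
    return 0;
}
```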
@@ -6444,6 +6599,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_mul_mat;
             break;
+        case GGML_OP_MUL_MAT_ID:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_mul_mat_id;
+            break;
         case GGML_OP_SCALE:
             if (!any_on_device) {
                 return false;