@@ -274,6 +274,92 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
274274    }
275275}
276276
// Dot product of q4_0-quantized matrix rows with a q8_0-quantized vector:
// dst[ir] = sum_j dequant(x[ir][j]) * dequant(y[j]).
//
// Launch layout (see dequantize_mul_mat_q4_0_cuda): grid of
// (nrows + NR - 1)/NR blocks, NT threads per block. Each block produces NR
// consecutive output rows; each thread accumulates a partial sum over
// 16-value slices of the row, and the partials are combined with a
// shared-memory tree reduction.
//
// Preconditions:
//   - NT must be a power of two (required by the tree reduction below)
//   - ncols is assumed to be a multiple of QK8_0; trailing columns of a
//     partial y block are not handled — TODO(review): confirm callers
//     guarantee this
template <int NT, int NR> static __global__ void dequantize_mul_mat_q4_0_test(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
    // the s = NT/2, NT/4, ... reduction only covers all lanes when NT is a power of two
    static_assert(NT > 0 && (NT & (NT - 1)) == 0, "NT must be a positive power of two");

    const block_q4_0 * x = (const block_q4_0 *) vx;
    const block_q8_0 * y = (const block_q8_0 *) vy;

    const int bid = blockIdx.x;
    const int tid = threadIdx.x;

    // per-row partial sums, one column per thread
    __shared__ float tmp[NR][NT];
    for (int i = 0; i < NR; ++i) {
        tmp[i][tid] = 0.0f;
    }

    const int nbc = (ncols + 16*NT - 1)/(16*NT); // number of 16*NT-wide column chunks
    const int nbm = ncols/QK8_0;                 // number of q8_0 blocks per row

    // scratch for one 8-byte load of quantized x (16 nibbles), viewed as bytes
    uint64_t xa0;
    uint64_t xa1;

    const int8_t * xb0 = (const int8_t *) &xa0;
    const int8_t * xb1 = (const int8_t *) &xa1;

    for (int ibc = 0; ibc < nbc; ++ibc) {
        const int iyb = (ibc*(16*NT) + 16*tid)/QK8_0; // y block index for this thread's slice
        const int iyq = (ibc*(16*NT) + 16*tid)%QK8_0; // element offset within that y block

        if (iyb >= nbm) {
            continue; // this thread's slice is past the end of the row
        }

        const int8_t * yb = (const int8_t *) &y[iyb].qs[iyq];

        const float dy = y[iyb].d;

        for (int ibr = 0; ibr < NR; ++ibr) {
            const int ir = bid*NR + ibr;
            if (ir >= nrows) {
                continue;
            }

            // block offset
            const int ixo = (ir*ncols)/QK4_0 + iyb;

            // load 8 quantized bytes = 16 nibbles; low nibbles are paired with
            // even y elements and high nibbles with odd ones below, i.e. this
            // assumes the interleaved q4_0 nibble layout
            memcpy(&xa0, &x[ixo].qs[iyq/2 + 0], sizeof(uint64_t));
            xa1 = xa0;

            xa0 = (xa0     ) & 0x0F0F0F0F0F0F0F0Full;
            xa1 = (xa1 >> 4) & 0x0F0F0F0F0F0F0F0Full;

            const float dx = x[ixo].d;

            // the (int) cast is probably unnecessary, but just to make sure the result is accumulated in 32 bits
            tmp[ibr][tid] += (
                    ((int)(xb0[0] - 8))*yb[0]  + ((int)(xb1[0] - 8))*yb[1]  +
                    ((int)(xb0[1] - 8))*yb[2]  + ((int)(xb1[1] - 8))*yb[3]  +
                    ((int)(xb0[2] - 8))*yb[4]  + ((int)(xb1[2] - 8))*yb[5]  +
                    ((int)(xb0[3] - 8))*yb[6]  + ((int)(xb1[3] - 8))*yb[7]  +
                    ((int)(xb0[4] - 8))*yb[8]  + ((int)(xb1[4] - 8))*yb[9]  +
                    ((int)(xb0[5] - 8))*yb[10] + ((int)(xb1[5] - 8))*yb[11] +
                    ((int)(xb0[6] - 8))*yb[12] + ((int)(xb1[6] - 8))*yb[13] +
                    ((int)(xb0[7] - 8))*yb[14] + ((int)(xb1[7] - 8))*yb[15]
                    )*dx*dy;
        }
    }

    // reduce: barrier is outside all divergent branches, so every thread reaches it
    __syncthreads();

    for (int s = NT/2; s > 0; s >>= 1) {
        if (tid < s) {
            for (int ibr = 0; ibr < NR; ++ibr) {
                tmp[ibr][tid] += tmp[ibr][tid + s];
            }
        }
        __syncthreads();
    }

    if (tid == 0) {
        for (int ibr = 0; ibr < NR; ++ibr) {
            const int ir = bid*NR + ibr;
            if (ir < nrows) {
                dst[ir] = tmp[ibr][0];
            }
        }
    }
}
362+ 
277363static  void  dequantize_row_q4_0_cuda (const  void  * vx, float  * y, int  k, cudaStream_t stream) {
278364    const  int  nb = k / QK4_0;
279365    dequantize_block_q4_0<<<nb, 1 , 0 , stream>>> (vx, y);
@@ -316,9 +402,14 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float
316402    //      }
317403    //  }
318404    //  dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
319-     const  int  block_size = 32 ;
320-     GGML_ASSERT (ncols % block_size == 0 );
321-     dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0 , stream>>> (vx, y, dst, ncols);
405+     // const int block_size = 32;
406+     // GGML_ASSERT(ncols % block_size == 0);
407+     // dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
408+ 
409+     const  int  NR = 1 ;  //  unroll rows (seems to not help)
410+     const  int  NT = 64 ; //  number of threads per row
411+ 
412+     dequantize_mul_mat_q4_0_test<NT, NR><<<(nrows + NR - 1 )/NR, NT, 0 , stream>>> (vx, y, dst, ncols, nrows);
322413}
323414
324415//  TODO: optimize
0 commit comments