@@ -235,8 +235,8 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
235235    __shared__  float  tmp[block_size]; //  separate sum for each thread
236236    tmp[tid] = 0 ;
237237
238-     for  (int  i = 0 ; i < ncols/block_size; i += 2 ) {
239-         const  int  col = i*block_size + 2 *tid;
238+     for  (int  i = 0 ; i < ncols/block_size; i += 4 ) {
239+         const  int  col = i*block_size + 4 *tid;
240240
241241        //  dequantize
242242        const  float  d0 = x[(row*ncols + col)/QK4_0].d ;
@@ -245,19 +245,21 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
245245        const  uint8_t  * p0 = x[(row*ncols + col)/QK4_0].qs ;
246246        const   int8_t  * p1 = y[col/QK8_0].qs ;
247247
248-         const  uint8_t  vui0 = p0[((row*ncols + col)%QK4_0)/2 ];
248+         const  uint8_t  vui00 = p0[((row*ncols + col)%QK4_0)/2 ];
249+         const  uint8_t  vui01 = p0[((row*ncols + col + 2 )%QK4_0)/2 ];
249250        const   int  vi10 = p1[(col + 0 )%QK8_0];
250251        const   int  vi11 = p1[(col + 1 )%QK8_0];
252+         const   int  vi12 = p1[(col + 2 )%QK8_0];
253+         const   int  vi13 = p1[(col + 3 )%QK8_0];
251254
252-         const  int  vi00 = vui0 & 0xF ;
253-         const  int  vi01 = vui0 >> 4 ;
254- 
255-         const  float  v0 = (vi00 - 8 )*vi10*d0*d1;
256-         const  float  v1 = (vi01 - 8 )*vi11*d0*d1;
255+         const  int  vi00 = vui00 & 0xF ;
256+         const  int  vi01 = vui00 >> 4 ;
257+         const  int  vi02 = vui01 & 0xF ;
258+         const  int  vi03 = vui01 >> 4 ;
257259
258260        //  matrix multiplication
259-         tmp[tid] += v0 ;
260-         tmp[tid] += v1 ;
261+         const   int  sumi = (vi00 -  8 )*vi10 + (vi01 -  8 )*vi11 + (vi02 -  8 )*vi12 + (vi03 -  8 )*vi13 ;
262+         tmp[tid] += sumi*d0*d1 ;
261263    }
262264
263265    //  sum up partial sums and write back result
0 commit comments