@@ -1833,7 +1833,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
18331833 const block_q4_0 * restrict x = vx ;
18341834 const block_q4_0 * restrict y = vy ;
18351835
1836- ggml_float sumf = 0.0 ;
1836+ float sumf = 0.0 ;
18371837
18381838#if defined(__ARM_NEON )
18391839 float sum0 = 0.0f ;
@@ -1928,7 +1928,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19281928#endif
19291929 }
19301930
1931- sumf = ( ggml_float )( sum0 + sum1 ) ;
1931+ sumf = sum0 + sum1 ;
19321932#elif defined(__AVX512F__ )
19331933 // Initialize accumulator with zeros
19341934 __m512 acc0 = _mm512_setzero_ps ();
@@ -1962,6 +1962,10 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19621962 __m256 acc = _mm256_setzero_ps ();
19631963
19641964 // Main loop
1965+ // TODO: figure out a portable way to request loop unrolling
1966+ #ifdef __GNUC__
1967+ #pragma GCC unroll 16
1968+ #endif
19651969 for (int i = 0 ; i < nb ; ++ i ) {
19661970 // Compute combined scale for the block
19671971 const __m256 d = _mm256_mul_ps ( _mm256_broadcast_ss ( & x [i ].d ), _mm256_broadcast_ss ( & y [i ].d ) );
@@ -1975,20 +1979,21 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19751979 bx = _mm256_sub_epi8 ( bx , off );
19761980 by = _mm256_sub_epi8 ( by , off );
19771981
1978- // Sign-extend first 16 signed bytes into int16_t
1979- __m256i x16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( bx ) );
1980- __m256i y16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
1981- // Compute products of int16_t integers, add pairwise
1982- __m256i i32 = _mm256_madd_epi16 ( x16 , y16 );
1982+ // Get absolute values of x vectors
1983+ const __m256i ax = _mm256_sign_epi8 (bx , bx );
19831984
1984- // Sign-extend last 16 signed bytes into int16_t vectors
1985- x16 = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( bx , 1 ) );
1986- y16 = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
1987- // Accumulate products of int16_t integers
1988- i32 = _mm256_add_epi32 ( i32 , _mm256_madd_epi16 ( x16 , y16 ) );
1985+ // Sign the values of the y vectors
1986+ const __m256i sy = _mm256_sign_epi8 (by , bx );
1987+
1988+ // Perform multiplication and create 16-bit values
1989+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
1990+
1991+ const __m256i ones = _mm256_set1_epi16 (1 );
1992+ const __m256i i32 = _mm256_madd_epi16 (ones , dot );
19891993
19901994 // Convert int32_t to float
1991- __m256 p = _mm256_cvtepi32_ps ( i32 );
1995+ const __m256 p = _mm256_cvtepi32_ps ( i32 );
1996+
19921997 // Apply the scale, and accumulate
19931998 acc = _mm256_fmadd_ps ( d , p , acc );
19941999 }
0 commit comments