@@ -583,7 +583,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc
 
 typedef struct {
     float   d;          // delta
-    uint8_t qs[QK];     // nibbles / quants
+    int8_t  qs[QK];     // quants
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
 
@@ -1060,9 +1060,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         for (int l = 0; l < QK; ++l) {
             const float v = x[i*QK + l]*id;
-            const uint8_t vi = (int8_t)roundf(v) + 128;
-
-            y[i].qs[l] = vi;
+            y[i].qs[l] = roundf(v);
         }
     }
 }
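The hunk above only shows the inner store; the block scale comes from the surrounding loop. For context, a sketch of the whole reference routine after this change, assuming QK = 32 and <math.h> for fabsf/fmaxf/roundf; the amax/d/id computation is reconstructed from the SIMD paths below and may differ cosmetically from the actual file:

static void quantize_row_q8_0_sketch(const float * restrict x, block_q8_0 * restrict y, int k) {
    const int nb = k / QK;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f;                  // max(abs(x)) over the block
        for (int l = 0; l < QK; l++) {
            amax = fmaxf(amax, fabsf(x[i*QK + l]));
        }

        const float d  = amax / 127.0f;     // scale so quants fit in [-127, 127]
        const float id = d != 0.0f ? 1.0f/d : 0.0f;

        y[i].d = d;

        for (int l = 0; l < QK; l++) {
            y[i].qs[l] = roundf(x[i*QK + l]*id);
        }
    }
}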
@@ -1095,15 +1093,99 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
-            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
-            const int32x4_t   vi = vcvtq_s32_f32(vf);
+            //TODO: rounding: vcvtq_s32_f32 truncates toward zero,
+            //      this should round to nearest (e.g. vcvtnq_s32_f32)
+            const int32x4_t   vi = vcvtq_s32_f32(v);
 
             y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
         }
     }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong:
+        // these AVX2 pack instructions process each 16-byte lane independently.
+        // The following permute fixes the order.
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since AVX lacks some of the necessary pack instructions,
+        // we split the registers in half and use the SSE analogs of the AVX2 ones
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1 );
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1 );
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1 );
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1 );
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs + 0),  ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
#else
    // scalar
    quantize_row_q8_0_reference(x, y, k);
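Why the permutevar8x32 above is needed: _mm256_packs_epi32 and _mm256_packs_epi16 pack each 128-bit lane independently, so after two packs the 32 bytes land in the lane-grouped order noted in the comments. A scalar model of that byte shuffle, for illustration only (assumes <stdint.h>):

// Scalar model of the AVX2 pack+permute byte ordering (illustration only).
// in[0..31] are the quantized values in logical order; out[] reproduces what
// _mm256_packs_epi32 / _mm256_packs_epi16 / _mm256_permutevar8x32_epi32 yield.
static void model_avx2_pack_order(const int8_t in[32], int8_t out[32]) {
    // After the two packs, byte j holds logical element packed[j]:
    static const int packed[32] = {
         0,  1,  2,  3,   8,  9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,
         4,  5,  6,  7,  12, 13, 14, 15,  20, 21, 22, 23,  28, 29, 30, 31,
    };
    int8_t tmp[32];
    for (int j = 0; j < 32; j++) {
        tmp[j] = in[packed[j]];
    }
    // The permute moves 4-byte group g from source group perm[g],
    // which restores the bytes to 0..31 order:
    static const int perm[8] = { 0, 4, 1, 5, 2, 6, 3, 7 };
    for (int g = 0; g < 8; g++) {
        for (int b = 0; b < 4; b++) {
            out[4*g + b] = tmp[4*perm[g] + b];
        }
    }
}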
@@ -2508,7 +2590,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
         const int8x16_t  s8b   = vdupq_n_s8(0x8);
-        const uint8x16_t u128b = vdupq_n_u8(128);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2526,21 +2607,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 
         // load y
-        const uint8x16_t v1_0l = vld1q_u8(y0->qs);
-        const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
-        const uint8x16_t v1_1l = vld1q_u8(y1->qs);
-        const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         // interleave
-        const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
-        const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
-        const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
-
-        const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
-        const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
-        const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
-        const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
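The vuzp1q_s8/vuzp2q_s8 pair de-interleaves y so that even-indexed quants line up with the low nibbles of x and odd-indexed quants with the high nibbles, matching the scalar pairing at the end of this function. A scalar model of that split, for illustration only (assumes <stdint.h>):

// Scalar model of the NEON de-interleave (illustration only):
// q4_0 stores element 2j in the low nibble of byte j and element 2j+1 in the
// high nibble, so the q8_0 side is split into even/odd halves to match.
static void model_uzp_split(const int8_t y_qs[32], int8_t even[16], int8_t odd[16]) {
    for (int j = 0; j < 16; j++) {
        even[j] = y_qs[2*j + 0];    // pairs with (x->qs[j] & 0xf) - 8
        odd[j]  = y_qs[2*j + 1];    // pairs with (x->qs[j] >>  4) - 8
    }
}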
@@ -2578,14 +2654,102 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     sumf = sum0 + sum1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytesFromNibbles(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert the vector of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
#else
    // scalar
    for (int i = 0; i < nb; i++) {
        const float d0 = x[i].d;
        const float d1 = y[i].d;

        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;

        int sumi = 0;
        for (int j = 0; j < QK/2; j++) {
@@ -2594,10 +2758,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
             const int i0 = (int8_t) (v0 & 0xf) - 8;
             const int i1 = (int8_t) (v0 >> 4) - 8;
 
-            const int i2 = (int) p1[2*j + 0] - 128;
-            const int i3 = (int) p1[2*j + 1] - 128;
-
-            /*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
 
             sumi += i0*i2 + i1*i3;
         }
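A note on the AVX kernels above: _mm256_maddubs_epi16 (and _mm_maddubs_epi16) multiplies unsigned bytes by signed bytes, so the sign of x is first transferred onto y via the sign instructions, using |x| * (sign(x)*y) == x*y. A scalar model of one such product, for illustration only (assumes <stdint.h>):

// Scalar model of the sign/maddubs trick (illustration only).
// maddubs needs its first operand unsigned, so the sign of x moves onto y.
// maddubs also sums adjacent pairs; this models a single product.
static int16_t model_maddubs_product(int8_t x, int8_t y) {
    const uint8_t ax = (uint8_t)(x < 0 ? -x : x);           // _mm_sign_epi8(bx, bx)
    const int8_t  sy = (int8_t)(x < 0 ? -y : (x ? y : 0));  // _mm_sign_epi8(by, bx)
    return (int16_t)(ax * sy);                              // equals x*y
}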
@@ -9923,7 +10085,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
                     } else
 #endif
-                    cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                    {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+                    }
                 } else {
                     GGML_ASSERT(false);
                 }
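The work-buffer size follows from the Q8_0 block layout: GGML_TYPE_SIZE[GGML_TYPE_Q8_0] is sizeof(block_q8_0) = 4 + 32 = 36 bytes and GGML_BLCK_SIZE[GGML_TYPE_Q8_0] is QK = 32, i.e. 1.125 bytes per quantized element of src1. A worked check with a hypothetical tensor size:

// Worked check (hypothetical src1 size, assuming QK = 32):
//   bytes per block:  sizeof(block_q8_0) = sizeof(float) + QK = 36
//   elements in src1: 4096 * 512 = 2097152
//   work buffer:      cur = 36 * 2097152 / 32 = 2359296 bytes (~2.25 MiB)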