@@ -6,7 +6,8 @@ using namespace metal;
66 *
77 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
88 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
9- * @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2
9+ * @param[scales] 2D tensor containg the scales for each group. Expected shape is #groups x N
10+ * @param[zeros] 2D tensor containg the zero points for each group. Expected shape is #groups x N
1011 * @param[outputData] M x N output tensor of floating point dtype (same as input)
1112 * @param[sizes] The sizes involved in the order: M, K, N
1213 *
@@ -16,9 +17,10 @@ template<typename T, unsigned nbit, unsigned groupSize>
1617kernel void divbit_mm (
1718 constant T * A [[buffer(0 )]],
1819 constant uchar * B [[buffer(1 )]],
19- constant T * scalesAndZeros [[buffer(2 )]],
20- device T * outputData [[buffer(3 )]],
21- constant uint3 & sizes [[buffer(4 )]], // M, K, N
20+ constant T * scales [[buffer(2 )]],
21+ constant T * zeros [[buffer(3 )]],
22+ device T * outputData [[buffer(4 )]],
23+ constant uint3 & sizes [[buffer(5 )]], // M, K, N
2224 uint2 thread_index [[thread_position_in_grid]]) {
2325 const uint K = sizes.y ;
2426 const uint N = sizes.z ;
@@ -35,29 +37,30 @@ kernel void divbit_mm(
3537 float rc = 0.0 ;
3638 uint k = 0 ;
3739 for (uint32_t kb = 0 ; kb < k_block ; kb ++) {
38- const T scale = scalesAndZeros[( kb * N + n) * 2 + 0 ] ;
39- const T zero = scalesAndZeros[( kb * N + n) * 2 + 1 ] - scale * T (zero_shift );
40+ const float scale = float (scales[ kb * N + n]) ;
41+ const float zero = float (zeros[ kb * N + n] );
4042 for (uint idx = 0 ; idx < groupSize && k < K; idx++, k++) {
4143 const auto a_val = float (A_ptr[k]);
4244 uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
4345 uint8_t shift = nbit * (k % values_per_byte);
4446 uint8_t mask = minimask << shift;
4547 b_val = (b_val & mask) >> shift;
46- rc += a_val * float (scale * T (b_val) + zero);
48+ rc += a_val * (scale * float (b_val) + zero);
4749 }
4850 }
4951 outputData[m * N + n] = T (rc);
5052}
5153
52- #define INSTANTIATE_DIVBIT_MM (NBIT, DTYPE, GSIZE ) \
54+ #define INSTANTIATE_DIVBIT_MM (NBIT, DTYPE, GSIZE ) \
5355template \
5456[[host_name(" int" #NBIT " pack_mm_" #GSIZE " _" #DTYPE)]] \
5557kernel void divbit_mm<DTYPE, NBIT, GSIZE>( \
5658 constant DTYPE * A [[buffer(0 )]], \
5759 constant uchar * B [[buffer(1 )]], \
58- constant DTYPE * scalesAndZeros [[buffer(2 )]], \
59- device DTYPE * outputData [[buffer(3 )]], \
60- constant uint3 & sizes [[buffer(4 )]], \
60+ constant DTYPE * scales [[buffer(2 )]], \
61+ constant DTYPE * zeros [[buffer(3 )]], \
62+ device DTYPE * outputData [[buffer(4 )]], \
63+ constant uint3 & sizes [[buffer(5 )]], \
6164 uint2 thread_index [[thread_position_in_grid]])
6265
6366INSTANTIATE_DIVBIT_MM (1 , float , 32 );
0 commit comments