Skip to content

Commit 657ebbb

Browse files
metal lowbit kernels: split scales and zero points
Differential Revision: D65232787 Pull Request resolved: #1202
1 parent a827d04 commit 657ebbb

File tree

9 files changed

+134
-97
lines changed

9 files changed

+134
-97
lines changed

torchao/experimental/kernels/mps/metal/divbit.metal

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using namespace metal;
66
*
77
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
88
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
9-
* @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2
9+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
10+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
1011
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1112
* @param[sizes] The sizes involved in the order: M, K, N
1213
*
@@ -16,9 +17,10 @@ template<typename T, unsigned nbit, unsigned groupSize>
1617
kernel void divbit_mm(
1718
constant T * A [[buffer(0)]],
1819
constant uchar * B [[buffer(1)]],
19-
constant T * scalesAndZeros [[buffer(2)]],
20-
device T * outputData [[buffer(3)]],
21-
constant uint3 & sizes [[buffer(4)]], // M, K, N
20+
constant T * scales [[buffer(2)]],
21+
constant T * zeros [[buffer(3)]],
22+
device T * outputData [[buffer(4)]],
23+
constant uint3 & sizes [[buffer(5)]], // M, K, N
2224
uint2 thread_index [[thread_position_in_grid]]) {
2325
const uint K = sizes.y;
2426
const uint N = sizes.z;
@@ -35,29 +37,30 @@ kernel void divbit_mm(
3537
float rc = 0.0;
3638
uint k = 0;
3739
for (uint32_t kb = 0; kb < k_block ; kb ++) {
38-
const T scale = scalesAndZeros[(kb * N + n) * 2 + 0];
39-
const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift);
40+
const float scale = float(scales[kb * N + n]);
41+
const float zero = float(zeros[kb * N + n]);
4042
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
4143
const auto a_val = float(A_ptr[k]);
4244
uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
4345
uint8_t shift = nbit * (k % values_per_byte);
4446
uint8_t mask = minimask << shift;
4547
b_val = (b_val & mask) >> shift;
46-
rc += a_val * float(scale * T(b_val) + zero);
48+
rc += a_val * (scale * float(b_val) + zero);
4749
}
4850
}
4951
outputData[m * N + n] = T(rc);
5052
}
5153

52-
#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
54+
#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
5355
template \
5456
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \
5557
kernel void divbit_mm<DTYPE, NBIT, GSIZE>( \
5658
constant DTYPE * A [[buffer(0)]], \
5759
constant uchar * B [[buffer(1)]], \
58-
constant DTYPE * scalesAndZeros [[buffer(2)]], \
59-
device DTYPE * outputData [[buffer(3)]], \
60-
constant uint3 & sizes [[buffer(4)]], \
60+
constant DTYPE * scales [[buffer(2)]], \
61+
constant DTYPE * zeros [[buffer(3)]], \
62+
device DTYPE * outputData [[buffer(4)]], \
63+
constant uint3 & sizes [[buffer(5)]], \
6164
uint2 thread_index [[thread_position_in_grid]])
6265

6366
INSTANTIATE_DIVBIT_MM(1, float, 32);

torchao/experimental/kernels/mps/metal/int3mm.metal

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using namespace metal;
66
*
77
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
88
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
9-
* @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2
9+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
10+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
1011
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1112
* @param[sizes] The sizes involved in the order: M, K, N
1213
*
@@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
1617
kernel void int3pack_mm(
1718
constant T * A [[buffer(0)]],
1819
constant uchar * B [[buffer(1)]],
19-
constant T * scalesAndZeros [[buffer(2)]],
20-
device T * outputData [[buffer(3)]],
21-
constant uint3 & sizes [[buffer(4)]], // M, K, N
20+
constant T * scales [[buffer(2)]],
21+
constant T * zeros [[buffer(3)]],
22+
device T * outputData [[buffer(4)]],
23+
constant uint3 & sizes [[buffer(5)]], // M, K, N
2224
uint2 thread_index [[thread_position_in_grid]]) {
2325
const uint K = sizes.y;
2426
const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int3pack_mm(
3133
float rc = 0.0;
3234
uint k = 0;
3335
for (uint32_t kb = 0; kb < k_block ; kb ++) {
34-
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
35-
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
36+
const float scale = float(scales[kb * N + n]);
37+
const float zero = float(zeros[kb * N + n]);
3638
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
3739
const auto a_val0 = float(A_ptr[k + 0]);
3840
const auto a_val1 = float(A_ptr[k + 1]);
@@ -76,9 +78,10 @@ template \
7678
kernel void int3pack_mm<DTYPE, GSIZE>( \
7779
constant DTYPE * A [[buffer(0)]], \
7880
constant uchar * B [[buffer(1)]], \
79-
constant DTYPE * scalesAndZeros [[buffer(2)]], \
80-
device DTYPE * outputData [[buffer(3)]], \
81-
constant uint3 & sizes [[buffer(4)]], \
81+
constant DTYPE * scales [[buffer(2)]], \
82+
constant DTYPE * zeros [[buffer(3)]], \
83+
device DTYPE * outputData [[buffer(4)]], \
84+
constant uint3 & sizes [[buffer(5)]], \
8285
uint2 thread_index [[thread_position_in_grid]])
8386

8487
INSTANTIATE_INT3MM(float, 32);

torchao/experimental/kernels/mps/metal/int5mm.metal

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using namespace metal;
66
*
77
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
88
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
9-
* @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2
9+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
10+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
1011
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1112
* @param[sizes] The sizes involved in the order: M, K, N
1213
*
@@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
1617
kernel void int5pack_mm(
1718
constant T * A [[buffer(0)]],
1819
constant uchar * B [[buffer(1)]],
19-
constant T * scalesAndZeros [[buffer(2)]],
20-
device T * outputData [[buffer(3)]],
21-
constant uint3 & sizes [[buffer(4)]], // M, K, N
20+
constant T * scales [[buffer(2)]],
21+
constant T * zeros [[buffer(3)]],
22+
device T * outputData [[buffer(4)]],
23+
constant uint3 & sizes [[buffer(5)]], // M, K, N
2224
uint2 thread_index [[thread_position_in_grid]]) {
2325
const uint K = sizes.y;
2426
const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int5pack_mm(
3133
float rc = 0.0;
3234
uint k = 0;
3335
for (uint32_t kb = 0; kb < k_block ; kb ++) {
34-
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
35-
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
36+
const float scale = float(scales[kb * N + n]);
37+
const float zero = float(zeros[kb * N + n]);
3638
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
3739
const auto a_val0 = float(A_ptr[k + 0]);
3840
const auto a_val1 = float(A_ptr[k + 1]);
@@ -78,9 +80,10 @@ template \
7880
kernel void int5pack_mm<DTYPE, GSIZE>( \
7981
constant DTYPE * A [[buffer(0)]], \
8082
constant uchar * B [[buffer(1)]], \
81-
constant DTYPE * scalesAndZeros [[buffer(2)]], \
82-
device DTYPE * outputData [[buffer(3)]], \
83-
constant uint3 & sizes [[buffer(4)]], \
83+
constant DTYPE * scales [[buffer(2)]], \
84+
constant DTYPE * zeros [[buffer(3)]], \
85+
device DTYPE * outputData [[buffer(4)]], \
86+
constant uint3 & sizes [[buffer(5)]], \
8487
uint2 thread_index [[thread_position_in_grid]])
8588

8689
INSTANTIATE_INT5MM(float, 32);

torchao/experimental/kernels/mps/metal/int6mm.metal

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using namespace metal;
66
*
77
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
88
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (6 * K / 8)
9-
* @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2
9+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
10+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
1011
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1112
* @param[sizes] The sizes involved in the order: M, K, N
1213
*
@@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
1617
kernel void int6pack_mm(
1718
constant T * A [[buffer(0)]],
1819
constant uchar * B [[buffer(1)]],
19-
constant T * scalesAndZeros [[buffer(2)]],
20-
device T * outputData [[buffer(3)]],
21-
constant uint3 & sizes [[buffer(4)]], // M, K, N
20+
constant T * scales [[buffer(2)]],
21+
constant T * zeros [[buffer(3)]],
22+
device T * outputData [[buffer(4)]],
23+
constant uint3 & sizes [[buffer(5)]], // M, K, N
2224
uint2 thread_index [[thread_position_in_grid]]) {
2325
const uint K = sizes.y;
2426
const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int6pack_mm(
3133
float rc = 0.0;
3234
uint k = 0;
3335
for (uint32_t kb = 0; kb < k_block ; kb ++) {
34-
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
35-
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(32);
36+
const float scale = float(scales[kb * N + n]);
37+
const float zero = float(zeros[kb * N + n]);
3638
for(uint idx = 0; idx < groupSize && k < K; idx+=4, k+=4) {
3739
const auto a_val0 = float(A_ptr[k + 0]);
3840
const auto a_val1 = float(A_ptr[k + 1]);
@@ -63,9 +65,10 @@ template \
6365
kernel void int6pack_mm<DTYPE, GSIZE>( \
6466
constant DTYPE * A [[buffer(0)]], \
6567
constant uchar * B [[buffer(1)]], \
66-
constant DTYPE * scalesAndZeros [[buffer(2)]], \
67-
device DTYPE * outputData [[buffer(3)]], \
68-
constant uint3 & sizes [[buffer(4)]], \
68+
constant DTYPE * scales [[buffer(2)]], \
69+
constant DTYPE * zeros [[buffer(3)]], \
70+
device DTYPE * outputData [[buffer(4)]], \
71+
constant uint3 & sizes [[buffer(5)]], \
6972
uint2 thread_index [[thread_position_in_grid]])
7073

7174
INSTANTIATE_INT6MM(float, 32);

torchao/experimental/kernels/mps/metal/int7mm.metal

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using namespace metal;
66
*
77
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
88
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (7 * K / 8)
9-
* @param[scalesAndZeros] 3D tensor containg the scales and zero point for each group. Expected shape is #groups x N x 2
9+
* @param[scales] 2D tensor containing the scales for each group. Expected shape is #groups x N
10+
* @param[zeros] 2D tensor containing the zero points for each group. Expected shape is #groups x N
1011
* @param[outputData] M x N output tensor of floating point dtype (same as input)
1112
* @param[sizes] The sizes involved in the order: M, K, N
1213
*
@@ -16,9 +17,10 @@ template<typename T, unsigned groupSize>
1617
kernel void int7pack_mm(
1718
constant T * A [[buffer(0)]],
1819
constant uchar * B [[buffer(1)]],
19-
constant T * scalesAndZeros [[buffer(2)]],
20-
device T * outputData [[buffer(3)]],
21-
constant uint3 & sizes [[buffer(4)]], // M, K, N
20+
constant T * scales [[buffer(2)]],
21+
constant T * zeros [[buffer(3)]],
22+
device T * outputData [[buffer(4)]],
23+
constant uint3 & sizes [[buffer(5)]], // M, K, N
2224
uint2 thread_index [[thread_position_in_grid]]) {
2325
const uint K = sizes.y;
2426
const uint N = sizes.z;
@@ -31,8 +33,8 @@ kernel void int7pack_mm(
3133
float rc = 0.0;
3234
uint k = 0;
3335
for (uint32_t kb = 0; kb < k_block ; kb ++) {
34-
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
35-
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(64);
36+
const float scale = float(scales[kb * N + n]);
37+
const float zero = float(zeros[kb * N + n]);
3638
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
3739
const auto a_val0 = float(A_ptr[k + 0]);
3840
const auto a_val1 = float(A_ptr[k + 1]);
@@ -80,9 +82,10 @@ template \
8082
kernel void int7pack_mm<DTYPE, GSIZE>( \
8183
constant DTYPE * A [[buffer(0)]], \
8284
constant uchar * B [[buffer(1)]], \
83-
constant DTYPE * scalesAndZeros [[buffer(2)]], \
84-
device DTYPE * outputData [[buffer(3)]], \
85-
constant uint3 & sizes [[buffer(4)]], \
85+
constant DTYPE * scales [[buffer(2)]], \
86+
constant DTYPE * zeros [[buffer(3)]], \
87+
device DTYPE * outputData [[buffer(4)]], \
88+
constant uint3 & sizes [[buffer(5)]], \
8689
uint2 thread_index [[thread_position_in_grid]])
8790

8891
INSTANTIATE_INT7MM(float, 32);

torchao/experimental/kernels/mps/src/lowbit.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ using DispatchFn =
8888
inline void linear_lowbit_quant_weights_mps_impl(
8989
id<MTLBuffer> a_buf,
9090
id<MTLBuffer> b_buf,
91-
id<MTLBuffer> sz_buf,
91+
id<MTLBuffer> s_buf,
92+
id<MTLBuffer> z_buf,
9293
id<MTLBuffer> out_buf,
9394
int32_t M,
9495
int32_t K,
@@ -111,11 +112,12 @@ inline void linear_lowbit_quant_weights_mps_impl(
111112
[computeEncoder setComputePipelineState:cpl];
112113
[computeEncoder setBuffer:a_buf offset:0 atIndex:0];
113114
[computeEncoder setBuffer:b_buf offset:0 atIndex:1];
114-
[computeEncoder setBuffer:sz_buf offset:0 atIndex:2];
115-
[computeEncoder setBuffer:out_buf offset:0 atIndex:3];
115+
[computeEncoder setBuffer:s_buf offset:0 atIndex:2];
116+
[computeEncoder setBuffer:z_buf offset:0 atIndex:3];
117+
[computeEncoder setBuffer:out_buf offset:0 atIndex:4];
116118
[computeEncoder setBytes:sizes.data()
117119
length:sizeof(uint32_t) * sizes.size()
118-
atIndex:4];
120+
atIndex:5];
119121
dispatch_fn(computeEncoder, maxThreadsPerGroup, M, N, K);
120122
finalize_block(mpsStream);
121123
}
@@ -128,7 +130,8 @@ void linear_lowbit_quant_weights_mps(
128130
id<MTLBuffer> a_buf,
129131
id<MTLBuffer> b_buf,
130132
int64_t qGroupSize,
131-
id<MTLBuffer> sz_buf,
133+
id<MTLBuffer> s_buf,
134+
id<MTLBuffer> z_buf,
132135
id<MTLBuffer> out_buf,
133136
int32_t M,
134137
int32_t K,
@@ -143,7 +146,8 @@ void linear_lowbit_quant_weights_mps(
143146
return linear_lowbit_quant_weights_mps_impl(
144147
a_buf,
145148
b_buf,
146-
sz_buf,
149+
s_buf,
150+
z_buf,
147151
out_buf,
148152
M,
149153
K,

0 commit comments

Comments
 (0)