Skip to content

Commit 2761ed5

Browse files
manuelcandalesjainapurva
authored andcommitted
Introduce lowbit quantized linear MPS kernels
Differential Revision: D63342895 Pull Request resolved: #954
1 parent ca2b385 commit 2761ed5

File tree

17 files changed

+1701
-0
lines changed

17 files changed

+1701
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from typing import Optional
import os
import yaml

# Generates metal_shader_lib.h: a single C++ header embedding the source of
# every .metal shader listed in metal.yaml as one MetalShaderLibrary string.
torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
assert torchao_root is not None, "TORCHAO_ROOT is not set"

MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")

# Path to yaml file containing the list of .metal files to include
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")

# Deduplicated, sorted list of shader files referenced by the op entries.
with open(METAL_YAML, "r") as yamlf:
    metal_config = yaml.safe_load(yamlf)
metal_files = sorted({entry["file"] for entry in metal_config if "file" in entry})

# Path to the folder containing the .metal files
METAL_DIR = os.path.join(MPS_DIR, "metal")

# Output file where the generated code will be written
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")

prefix = """/**
 * This file is generated by gen_metal_shader_lib.py
 */

#ifdef ATEN
using namespace at::native::mps;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
"""

suffix = """
)METAL_LOWBIT");
"""

comment = """
/**
 * Contents of {}
 */

"""

# Stitch each shader, preceded by a provenance comment, between the
# prefix and suffix.
with open(OUTPUT_FILE, "w") as outf:
    outf.write(prefix)
    for shader_name in metal_files:
        with open(os.path.join(METAL_DIR, shader_name), "r") as shader_file:
            outf.write(comment.format(shader_name))
            outf.write(shader_file.read())
            outf.write("\n\n")
    outf.write(suffix)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Maps each lowbit quantized linear op to the .metal source file that
# implements it; the shader-lib generator script reads the `file` entries
# to assemble the embedded Metal library header.
# Bitwidths that divide 8 (1, 2, 4) share the generic divbit.metal kernel;
# 3/5/6/7-bit each have a dedicated packing-specific shader.
- func: int1mm
  file: divbit.metal

- func: int2mm
  file: divbit.metal

- func: int3mm
  file: int3mm.metal

- func: int4mm
  file: divbit.metal

- func: int5mm
  file: int5mm.metal

- func: int6mm
  file: int6mm.metal

- func: int7mm
  file: int7mm.metal
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#include <metal_stdlib>
using namespace metal;

/**
 * LowBit Quantized Linear for bitwidths that are divisors of 8. Hence the name.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 * Each thread computes one output element: dot(A row m, dequantized B row n).
 */
template<typename T, unsigned nbit, unsigned groupSize>
kernel void divbit_mm(
    constant T * A              [[buffer(0)]],
    constant uchar * B          [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData       [[buffer(3)]],
    constant uint3 & sizes      [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along K (ceil division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B;

  // Values are stored unsigned; 2^(nbit-1) recenters them via the zero point.
  constexpr uint8_t zero_shift = 1 << (nbit - 1);
  constexpr uint8_t values_per_byte = 8 / nbit;
  constexpr uint8_t minimask = (1 << nbit) - 1; // low nbit bits set

  float rc = 0.0; // accumulate the dot product in float regardless of T
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    // Fix: compute scale/zero in float (was T). For half/bfloat T the
    // intermediate products lost precision; this also makes divbit_mm
    // consistent with the int3/int5/int6/int7 kernels, which already
    // do the dequantization arithmetic in float.
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(zero_shift);
    for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
      const auto a_val = float(A_ptr[k]);
      // Each byte of B packs values_per_byte consecutive nbit values; value
      // k of row n lives at bit offset nbit*(k % values_per_byte) of byte
      // (n*K + k) / values_per_byte.
      uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
      uint8_t shift = nbit * (k % values_per_byte);
      uint8_t mask = minimask << shift;
      b_val = (b_val & mask) >> shift;
      rc += a_val * (scale * float(b_val) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE)                        \
template                                                                 \
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]]                  \
kernel void divbit_mm<DTYPE, NBIT, GSIZE>(                               \
    constant DTYPE * A              [[buffer(0)]],                       \
    constant uchar * B              [[buffer(1)]],                       \
    constant DTYPE * scalesAndZeros [[buffer(2)]],                       \
    device DTYPE * outputData       [[buffer(3)]],                       \
    constant uint3 & sizes          [[buffer(4)]],                       \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_DIVBIT_MM(1, float, 32);
INSTANTIATE_DIVBIT_MM(1, half, 32);
INSTANTIATE_DIVBIT_MM(1, float, 64);
INSTANTIATE_DIVBIT_MM(1, half, 64);
INSTANTIATE_DIVBIT_MM(1, float, 128);
INSTANTIATE_DIVBIT_MM(1, half, 128);
INSTANTIATE_DIVBIT_MM(1, float, 256);
INSTANTIATE_DIVBIT_MM(1, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(1, bfloat, 32);
INSTANTIATE_DIVBIT_MM(1, bfloat, 64);
INSTANTIATE_DIVBIT_MM(1, bfloat, 128);
INSTANTIATE_DIVBIT_MM(1, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(2, float, 32);
INSTANTIATE_DIVBIT_MM(2, half, 32);
INSTANTIATE_DIVBIT_MM(2, float, 64);
INSTANTIATE_DIVBIT_MM(2, half, 64);
INSTANTIATE_DIVBIT_MM(2, float, 128);
INSTANTIATE_DIVBIT_MM(2, half, 128);
INSTANTIATE_DIVBIT_MM(2, float, 256);
INSTANTIATE_DIVBIT_MM(2, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(2, bfloat, 32);
INSTANTIATE_DIVBIT_MM(2, bfloat, 64);
INSTANTIATE_DIVBIT_MM(2, bfloat, 128);
INSTANTIATE_DIVBIT_MM(2, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(4, float, 32);
INSTANTIATE_DIVBIT_MM(4, half, 32);
INSTANTIATE_DIVBIT_MM(4, float, 64);
INSTANTIATE_DIVBIT_MM(4, half, 64);
INSTANTIATE_DIVBIT_MM(4, float, 128);
INSTANTIATE_DIVBIT_MM(4, half, 128);
INSTANTIATE_DIVBIT_MM(4, float, 256);
INSTANTIATE_DIVBIT_MM(4, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(4, bfloat, 32);
INSTANTIATE_DIVBIT_MM(4, bfloat, 64);
INSTANTIATE_DIVBIT_MM(4, bfloat, 128);
INSTANTIATE_DIVBIT_MM(4, bfloat, 256);
#endif
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#include <metal_stdlib>
using namespace metal;

/**
 * 3-Bit Quantized Linear.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 * Each thread computes one output element: dot(A row m, dequantized B row n).
 *
 * Packing layout (as decoded below): every 8 consecutive 3-bit values of a
 * row occupy 3 bytes. Byte b0 holds the MSB (bit 2) of each of the 8 values
 * (bit i of b0 -> value i); b1 holds the low 2 bits of values 0..3 (2 bits
 * each, LSB-first); b2 holds the low 2 bits of values 4..7.
 *
 * NOTE(review): the inner loop advances k by 8 and reads A_ptr[k..k+7] with
 * only `k < K` checked at the top, so K (and groupSize) are presumably
 * multiples of 8 — confirm against the host-side packing code.
 */
template<typename T, unsigned groupSize>
kernel void int3pack_mm(
    constant T * A              [[buffer(0)]],
    constant uchar * B          [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData       [[buffer(3)]],
    constant uint3 & sizes      [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along K (ceil division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  // Row n of B: 3 bytes per 8 values -> 3 * K / 8 bytes per row.
  constant uchar *B_ptr = B + n * 3 * K / 8;

  float rc = 0.0; // accumulate the dot product in float regardless of T
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    // 4 == 1 << (3 - 1): recenters the unsigned stored values [0, 7].
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
    for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      // The 3-byte packet holding values k..k+7 of this row.
      uchar b0 = B_ptr[3 * (k / 8) + 0];
      uchar b1 = B_ptr[3 * (k / 8) + 1];
      uchar b2 = B_ptr[3 * (k / 8) + 2];

      // Reassemble value i as (MSB from b0 bit i) << 2 | (low 2 bits from b1/b2).
      uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
      uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
      uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
      uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);

      uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
      uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
      uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
      uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);

      // Dequantize (scale * w + zero) and accumulate.
      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

// Instantiates int3pack_mm for a dtype/group-size pair under the host-visible
// name "int3pack_mm_<GSIZE>_<DTYPE>".
#define INSTANTIATE_INT3MM(DTYPE, GSIZE)                                 \
template                                                                 \
[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]]                          \
kernel void int3pack_mm<DTYPE, GSIZE>(                                   \
    constant DTYPE * A              [[buffer(0)]],                       \
    constant uchar * B              [[buffer(1)]],                       \
    constant DTYPE * scalesAndZeros [[buffer(2)]],                       \
    device DTYPE * outputData       [[buffer(3)]],                       \
    constant uint3 & sizes          [[buffer(4)]],                       \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT3MM(float, 32);
INSTANTIATE_INT3MM(half, 32);
INSTANTIATE_INT3MM(float, 64);
INSTANTIATE_INT3MM(half, 64);
INSTANTIATE_INT3MM(float, 128);
INSTANTIATE_INT3MM(half, 128);
INSTANTIATE_INT3MM(float, 256);
INSTANTIATE_INT3MM(half, 256);
// bfloat requires Metal 3.1+
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT3MM(bfloat, 32);
INSTANTIATE_INT3MM(bfloat, 64);
INSTANTIATE_INT3MM(bfloat, 128);
INSTANTIATE_INT3MM(bfloat, 256);
#endif
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#include <metal_stdlib>
using namespace metal;

/**
 * 5-Bit Quantized Linear.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 * Each thread computes one output element: dot(A row m, dequantized B row n).
 *
 * Packing layout (as decoded below): every 8 consecutive 5-bit values of a
 * row occupy 5 bytes. Byte b0 holds the MSB (bit 4) of each of the 8 values
 * (bit i of b0 -> value i); bytes b1..b4 hold the low 4 bits as nibbles:
 * b1 = values 0 (low nibble) and 1 (high nibble), b2 = values 2/3,
 * b3 = values 4/5, b4 = values 6/7.
 *
 * NOTE(review): the inner loop advances k by 8 and reads A_ptr[k..k+7] with
 * only `k < K` checked at the top, so K (and groupSize) are presumably
 * multiples of 8 — confirm against the host-side packing code.
 */
template<typename T, unsigned groupSize>
kernel void int5pack_mm(
    constant T * A              [[buffer(0)]],
    constant uchar * B          [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData       [[buffer(3)]],
    constant uint3 & sizes      [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  // Number of quantization groups along K (ceil division).
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  // Row n of B: 5 bytes per 8 values -> 5 * K / 8 bytes per row.
  constant uchar *B_ptr = B + n * 5 * K / 8;

  float rc = 0.0; // accumulate the dot product in float regardless of T
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block ; kb ++) {
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    // 16 == 1 << (5 - 1): recenters the unsigned stored values [0, 31].
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
    for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      // The 5-byte packet holding values k..k+7 of this row.
      uchar b0 = B_ptr[5 * (k / 8) + 0];
      uchar b1 = B_ptr[5 * (k / 8) + 1];
      uchar b2 = B_ptr[5 * (k / 8) + 2];
      uchar b3 = B_ptr[5 * (k / 8) + 3];
      uchar b4 = B_ptr[5 * (k / 8) + 4];

      // Reassemble value i as (MSB from b0 bit i) << 4 | (nibble from b1..b4).
      uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15);
      uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4);
      uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15);
      uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4);

      uchar w_val4 = ((b0 & 16)) | (b3 & 15);
      uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4);
      uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15);
      uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4);

      // Dequantize (scale * w + zero) and accumulate.
      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

// Instantiates int5pack_mm for a dtype/group-size pair under the host-visible
// name "int5pack_mm_<GSIZE>_<DTYPE>".
#define INSTANTIATE_INT5MM(DTYPE, GSIZE)                                 \
template                                                                 \
[[host_name("int5pack_mm_" #GSIZE "_" #DTYPE)]]                          \
kernel void int5pack_mm<DTYPE, GSIZE>(                                   \
    constant DTYPE * A              [[buffer(0)]],                       \
    constant uchar * B              [[buffer(1)]],                       \
    constant DTYPE * scalesAndZeros [[buffer(2)]],                       \
    device DTYPE * outputData       [[buffer(3)]],                       \
    constant uint3 & sizes          [[buffer(4)]],                       \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT5MM(float, 32);
INSTANTIATE_INT5MM(half, 32);
INSTANTIATE_INT5MM(float, 64);
INSTANTIATE_INT5MM(half, 64);
INSTANTIATE_INT5MM(float, 128);
INSTANTIATE_INT5MM(half, 128);
INSTANTIATE_INT5MM(float, 256);
INSTANTIATE_INT5MM(half, 256);
// bfloat requires Metal 3.1+
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT5MM(bfloat, 32);
INSTANTIATE_INT5MM(bfloat, 64);
INSTANTIATE_INT5MM(bfloat, 128);
INSTANTIATE_INT5MM(bfloat, 256);
#endif

0 commit comments

Comments
 (0)