Skip to content

Commit 0cb7a06

Browse files
authored
opencl: add q8_0 mm support (ggml-org#16469)
* opencl: add mm_q8_0_f32 * opencl: fix data loading for incomplete tile * opencl: use q8_0 mm for larger matrix * opencl: add some tests to cover the path
1 parent d93f843 commit 0cb7a06

File tree

6 files changed

+271
-19
lines changed

6 files changed

+271
-19
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS
9393
mul_mv_id_mxfp4_f32_flat
9494
mul_mm_f32_f32_l4_lm
9595
mul_mm_f16_f32_l4_lm
96+
mul_mm_q8_0_f32_l4_lm
9697
mul
9798
norm
9899
relu

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
408408
cl_program program_mul_mv_id_mxfp4_f32_flat;
409409
cl_program program_mul_mm_f32_f32_l4_lm;
410410
cl_program program_mul_mm_f16_f32_l4_lm;
411+
cl_program program_mul_mm_q8_0_f32_l4_lm;
411412

412413
cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
413414
cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
480481
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
481482
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
482483
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
484+
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
483485

484486
std::vector<ProfilingInfo> profiling_info;
485487

@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
11911193
GGML_LOG_CONT(".");
11921194
}
11931195

1196+
// mul_mm_q8_0_f32_l4_lm
1197+
{
1198+
#ifdef GGML_OPENCL_EMBED_KERNELS
1199+
const std::string kernel_src {
1200+
#include "mul_mm_q8_0_f32_l4_lm.cl.h"
1201+
};
1202+
#else
1203+
const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl");
1204+
#endif
1205+
backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
1206+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1207+
1208+
CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err));
1209+
GGML_LOG_CONT(".");
1210+
}
1211+
11941212
// mul
11951213
{
11961214
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6961,6 +6979,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
69616979
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
69626980
return;
69636981
}
6982+
case GGML_TYPE_Q8_0: {
6983+
if (ne11 < 32) {
6984+
break;
6985+
}
6986+
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
6987+
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
6988+
6989+
int batch_stride_a = ne00*ne01;
6990+
int batch_stride_b = ne10*ne11;
6991+
int batch_stride_d = ne0*ne1;
6992+
6993+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q));
6994+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d));
6995+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
6996+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
6997+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
6998+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
6999+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
7000+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
7001+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
7002+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
7003+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
7004+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
7005+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
7006+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
7007+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
7008+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
7009+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
7010+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
7011+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
7012+
7013+
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
7014+
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
7015+
size_t local_work_size[] = {(size_t)nth0, 1, 1};
7016+
7017+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
7018+
return;
7019+
}
69647020
default:
69657021
break;
69667022
}

ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
7979

8080
for (int block = 0; block < ne00; block += BK) {
8181
for (int l = 0; l < BM; l += loadstride_a) {
82+
if (loadc_a + l < ne01) {
8283
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
83-
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
84-
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
85-
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
86-
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
84+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
85+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
86+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
87+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
88+
} else {
89+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h;
90+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h;
91+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h;
92+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h;
93+
}
8794
}
8895

8996
for (int l = 0; l < BN; l += loadstride_b) {
90-
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
91-
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
92-
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
93-
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
94-
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
97+
if (loadc_b + l < ne11) {
98+
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
99+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
100+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
101+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
102+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
103+
} else {
104+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h;
105+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h;
106+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h;
107+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h;
108+
}
95109
}
96110

97111
barrier(CLK_LOCAL_MEM_FENCE);

ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
7979

8080
for (int block = 0; block < ne00; block += BK) {
8181
for (int l = 0; l < BM; l += loadstride_a) {
82-
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
83-
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
84-
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
85-
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
86-
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
82+
if (loadc_a + l < ne01) {
83+
const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
84+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
85+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
86+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
87+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
88+
} else {
89+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
90+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
91+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
92+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
93+
}
8794
}
8895

8996
for (int l = 0; l < BN; l += loadstride_b) {
90-
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
91-
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
92-
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
93-
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
94-
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
97+
if (loadc_b + l < ne11) {
98+
const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
99+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
100+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
101+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
102+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
103+
} else {
104+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
105+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
106+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
107+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
108+
}
95109
}
96110

97111
barrier(CLK_LOCAL_MEM_FENCE);
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2+
3+
#define LOAD_VEC_A 4
4+
#define LOAD_VEC_B 4
5+
6+
#define BM 64
7+
#define BN 64
8+
#define BK 32
9+
#define TM 4
10+
#define TN 8
11+
12+
kernel void kernel_mul_mm_q8_0_f32_l4_lm(
13+
global char4 * src0_q,
14+
global half * src0_d,
15+
global float4 * src1,
16+
ulong offset1,
17+
global float * dst,
18+
ulong offsetd,
19+
20+
int ne00,
21+
int ne01,
22+
int ne02,
23+
int ne11,
24+
int ne12,
25+
26+
int stride_a,
27+
int stride_b,
28+
int stride_d,
29+
30+
int batch_stride_a,
31+
int batch_stride_b,
32+
int batch_stride_d,
33+
34+
int r2,
35+
int r3
36+
) {
37+
src1 = (global float4*)((global char*)src1 + offset1);
38+
dst = (global float *)((global char*)dst + offsetd);
39+
40+
local float buf_a[BM * BK];
41+
local float buf_b[BN * BK];
42+
43+
const int batch_idx = get_global_id(2);
44+
45+
const int i13 = batch_idx / ne12;
46+
const int i12 = batch_idx % ne12;
47+
48+
const int i03 = i13 / r3;
49+
const int i02 = i12 / r2;
50+
51+
const int batch_idx_a = i03 * ne02 + i02;
52+
53+
const int ir = get_group_id(0);
54+
const int ic = get_group_id(1);
55+
56+
const int tid = get_local_id(0);
57+
const int th_r = tid % (BM / TM);
58+
const int th_c = tid / (BM / TM);
59+
60+
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
61+
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
62+
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
63+
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
64+
65+
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
66+
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
67+
68+
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
69+
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
70+
71+
float sums[TM * TN];
72+
float cache_a[TM];
73+
float cache_b[TN];
74+
75+
for (int i = 0; i < TM * TN; i++) {
76+
sums[i] = 0.0f;
77+
}
78+
79+
for (int block = 0; block < ne00; block += BK) {
80+
for (int l = 0; l < BM; l += loadstride_a) {
81+
if (loadc_a + l < ne01) {
82+
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
83+
int ib = idx / 8;
84+
int iqs = idx % 8;
85+
86+
float d = (float)src0_d[ib];
87+
global char4 * qs = src0_q + ib*8 + iqs;
88+
char4 q = *qs;
89+
float4 v = convert_float4(q)*d;
90+
91+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
92+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
93+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
94+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
95+
} else {
96+
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
97+
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
98+
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
99+
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
100+
}
101+
}
102+
103+
for (int l = 0; l < BN; l += loadstride_b) {
104+
if (loadc_b + l < ne11) {
105+
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
106+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
107+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
108+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
109+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
110+
} else {
111+
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
112+
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
113+
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
114+
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
115+
}
116+
}
117+
118+
barrier(CLK_LOCAL_MEM_FENCE);
119+
120+
pos_a += BK / LOAD_VEC_A;
121+
pos_b += BK / LOAD_VEC_B;
122+
123+
for (int i = 0; i < BK; i++) {
124+
for (int j = 0; j < TM; j++) {
125+
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
126+
}
127+
128+
for (int j = 0; j < TN; j++) {
129+
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
130+
}
131+
132+
for (int cc = 0; cc < TN; cc++) {
133+
for (int cr = 0; cr < TM; cr++) {
134+
const int sums_idx = cc*TM + cr;
135+
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
136+
}
137+
}
138+
}
139+
barrier(CLK_LOCAL_MEM_FENCE);
140+
}
141+
142+
const int dr = ir * BM + th_r * TM;
143+
const int dc = ic * BN + th_c * TN;
144+
145+
const int offsets = batch_idx * batch_stride_d;
146+
147+
for (int cc = 0; cc < TN; cc++) {
148+
for (int cr = 0; cr < TM; cr++) {
149+
if (dr + cr < ne01 && dc + cc < ne11) {
150+
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
151+
}
152+
}
153+
}
154+
}

tests/test-backend-ops.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6365,6 +6365,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
63656365
}
63666366
}
63676367

6368+
#if 0
6369+
{
6370+
// Test paths in OpenCL
6371+
std::vector<int> ns = {32, 64, 128, 256, 512, 1024, 4096};
6372+
std::vector<int> ks = {896, 1536, 4096};
6373+
for (auto n : ns) {
6374+
for (auto k : ks) {
6375+
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1}));
6376+
}
6377+
}
6378+
}
6379+
#endif
6380+
63686381
#if 1
63696382
for (ggml_type type_a : base_types) {
63706383
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {

0 commit comments

Comments
 (0)