
Commit 614c667

Add sparse marlin 2:4 gemm op (#733)
feat: add sparse marlin 2:4 kernel
1 parent aacaf9b commit 614c667

File tree

12 files changed: +2542 -1 lines changed


test/test_ops.py

Lines changed: 116 additions & 1 deletion
@@ -10,8 +10,9 @@
     run_tests,
 )
 from torch.testing._internal.optests import opcheck
-from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
 from torchao.prototype.quant_llm import from_scaled_tc_fpx
+from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
 import pytest
 
 if is_fbcode():
@@ -302,5 +303,119 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size)
         test_utils=test_utils,
     )
 
+
+MARLIN_24_K_CHUNKS = [128]
+MARLIN_24_N_CHUNKS = [512]
+MNK_FACTORS = [
+    (1, 1, 1),
+    (1, 4, 8),
+    (1, 7, 5),
+    (13, 17, 67),
+    (26, 37, 13),
+    (67, 13, 11),
+]
+MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
+MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
+
+MARLIN_TEST_PARAMS = list(itertools.product(
+    MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
+    MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
+))
+
+def _symmetric_quantize_with_ref(w: torch.Tensor, num_bits: int, group_size: int):
+    orig_device = w.device
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    max_q_val = 2**num_bits - 1
+    half_q_val = (max_q_val + 1) // 2
+
+    # Reshape to [group_size, -1]
+    if group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
+    s *= 2 / max_q_val  # 2 => symmetric
+
+    # Quantize
+    q_w = torch.round(w / s).int()
+    q_w += half_q_val
+    q_w = torch.clamp(q_w, 0, max_q_val)
+
+    # Compute ref (dequantized)
+    w_ref = (q_w - half_q_val).half() * s
+
+    # Restore original shapes
+    if group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        q_w = reshape_w(q_w)
+        w_ref = reshape_w(w_ref)
+
+    s = s.reshape((-1, size_n)).contiguous()
+
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        s.to(device=orig_device),
+    )
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, ids=str)
+def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
+    b_weight = torch.rand((size_k, size_n), dtype=torch.float16, device="cuda")
+
+    # Inject 2:4 sparsity (keep at most two nonzero values in every contiguous group of four)
+    w_24, _ = inject_24(b_weight, size_k, size_n)
+
+    # Symmetric quantize
+    w_24_ref, q_w_24, scale = _symmetric_quantize_with_ref(w_24, num_bits, group_size)
+
+    # Obtain the reference output from the dense dequantized weights
+    output_ref = torch.matmul(a_input, w_24_ref)
+
+    # Pack to marlin 2:4 format
+    marlin_24_q_w_comp, marlin_24_scale, meta = pack_to_marlin_24(q_w_24, scale, num_bits, group_size)
+    workspace_24 = marlin_24_workspace(size_n)
+
+    fn_inputs = (
+        a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
+        num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
+    )
+    output = torchao.ops.marlin_24_gemm(*fn_inputs)
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    assert max_diff < 0.04
+
+    # Perform opcheck
+    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
+    opcheck(
+        torch.ops.torchao.marlin_24_gemm,
+        fn_inputs,
+        test_utils=test_utils,
+    )
+
+
 if __name__ == "__main__":
     run_tests()
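
For intuition, here is a minimal worked example of what the symmetric quantizer above computes for num_bits=4 with a single group per column (group_size == size_k). This sketch is not part of the commit; the values are illustrative:

import torch

# num_bits = 4 gives max_q_val = 15 and half_q_val = 8, so quantized values
# land in [0, 15] with 8 representing zero.
num_bits = 4
max_q_val = 2**num_bits - 1        # 15
half_q_val = (max_q_val + 1) // 2  # 8

w = torch.tensor([[-1.0], [0.5], [2.0]], dtype=torch.float16)

# Per-group scale: s = 2 * max|w| / max_q_val = 2 * 2.0 / 15 ~= 0.2667
s = torch.max(torch.abs(w), 0, keepdim=True)[0] * 2 / max_q_val

# Round, shift to unsigned, clamp: q_w == [[4], [10], [15]]
q_w = torch.clamp(torch.round(w / s).int() + half_q_val, 0, max_q_val)

# Dequantized reference, the same tensor the test matmuls against the kernel output
w_ref = (q_w - half_q_val).half() * s

The max_diff < 0.04 assertion in the test then only has to tolerate this quantization and fp16 rounding error rather than demand bitwise equality with the reference matmul.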
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2024 Roberto Lopez Castro ([email protected]). All
+ * Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace torchao {
+
+constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed
+// for instance as inputs to tensor core operations. Consequently, all
+// corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee
+// this.
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) { return elems[i]; }
+};
+
+template <int M_, int N_, int K_>
+struct ShapeBase {
+  static constexpr int M = M_, N = N_, K = K_;
+};
+
+using I4 = Vec<int, 4>;
+
+// Matrix fragments for tensor core instructions; their precise layout is
+// documented here:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+using FragA = Vec<half2, 4>;
+using FragB = Vec<half2, 2>;
+using FragM = Vec<uint, 1>;
+using FragC = Vec<float, 4>;
+using FragS = Vec<half2, 1>;  // quantization scales
+
+} // namespace torchao
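
To make the register-grouping comment above concrete, here is a hypothetical usage sketch (not code from this commit): a `Vec` is only ever indexed inside fully unrolled loops, so every access is a compile-time constant and the elements can be held in registers.

// Hypothetical sketch, not from this commit: with #pragma unroll the
// compiler expands the loop into four constant-index accesses, so
// frag.elems[0..3] can stay in registers rather than spilling to local memory.
__device__ float reduce_frag(torchao::FragC& frag) {
  float acc = 0.0f;
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    acc += frag[i];  // constant index after unrolling
  }
  return acc;
}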
