
Commit bbbc915

support Marlin W4A8 GEMM
1 parent 2761917 commit bbbc915

File tree

15 files changed: +4105 additions, -482 deletions
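
The commit adds a Marlin-based W4A8 GEMM kernel (int8 dynamic activations with int4 weights, QQQ scheme) and exposes it through torchao's quantize_ API via the new MarlinQQQLayout. A minimal usage sketch, mirroring the test file further below; the toy model and input here are illustrative, not part of the commit:

import torch
from torch import nn
from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.quantization.quant_primitives import MappingType

# Illustrative fp16 model; the Marlin QQQ kernel requires a CUDA device.
model = nn.Sequential(nn.Linear(4096, 4096)).half().cuda()

# Quantize in place: int8 per-token activations, int4 grouped weights,
# packed into the Marlin QQQ layout added by this commit.
quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=128,
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=MarlinQQQLayout(),
    ),
)

out = model(torch.randn((1, 4096), dtype=torch.float16, device="cuda"))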

benchmarks/benchmark_marlin_qqq.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import marlin_qqq_gemm
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq
from tqdm import tqdm


def get_problem(m, n, k, groupsize=-1):
    # groupsize == -1 means channelwise quantization (one group spanning all of k).
    if groupsize == -1:
        groupsize = k
    dev = torch.device("cuda")
    # fp16 reference operands for the baseline matmul.
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    # int8 activations and int4-packed weights with per-token, per-group,
    # and per-channel scales for the Marlin QQQ kernel.
    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    B = torch.randint(low=-(2**31), high=2**31, size=(k, n), device=dev)
    s_tok = torch.ones((m, 1), dtype=torch.float, device=dev)
    if groupsize == k:
        s_group = torch.tensor([], dtype=torch.half, device=dev)
    else:
        s_group = torch.ones((k // groupsize, n), dtype=torch.half, device=dev)
    s_channel = torch.ones((1, n), dtype=torch.float, device=dev)
    B, s_group, s_channel = pack_to_marlin_qqq(
        B, s_group, s_channel, num_bits=4, group_size=groupsize
    )
    qqq_workspace = marlin_qqq_workspace(n)
    return A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace


def benchmark(m: int, k: int, n: int, group_size: int):
    A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace = get_problem(
        m, n, k, group_size
    )

    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
    marlin_qqq_w4a8_time = benchmark_torch_function_in_microseconds(
        marlin_qqq_gemm, A, B, s_tok, s_channel, s_group, qqq_workspace, m, n, k
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "group_size": group_size,
        "fp16_latency (us)": fp16_time,
        "marlin_qqq_w4a8_latency (us)": marlin_qqq_w4a8_time,
        "speedup (d/s)": fp16_time / marlin_qqq_w4a8_time,
    }


if __name__ == "__main__":
    # (k, n) shapes typical of large-model linear layers.
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for group_size in tqdm([-1, 128]):
        for m in tqdm([1 << i for i in range(10)]):
            for n, k in zip(n_vals, k_vals):
                results.append(benchmark(m, k, n, group_size))

    df = pd.DataFrame(results)
    df.to_csv("marlin_qqq_w4a8_llm_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
import torch
import copy
import pytest

from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.marlin_qqq import (
    pack_to_marlin_qqq,
    unpack_from_marlin_qqq,
)
from torchao.quantization.quant_primitives import (
    choose_qparams_and_quantize_affine_qqq,
    MappingType,
)


class MarlinQQQ(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)

        self.input = torch.randn((32, 16, 4096), dtype=torch.float16, device="cuda")
        self.model = (
            nn.Sequential(
                nn.Linear(4096, 21504),
                nn.Linear(21504, 4096),
                nn.ReLU(),
                nn.Linear(4096, 21504),
                nn.Linear(21504, 4096),
            )
            .half()
            .cuda()
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq(self):
        output_ref = self.model(self.input)
        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq_compile(self):
        model_copy = copy.deepcopy(self.model)
        model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
        output_ref = model_copy(self.input)

        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            modelq.forward = torch.compile(modelq.forward, fullgraph=True)
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        shape = (11008, 4096)
        mapping_type = MappingType.SYMMETRIC

        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        for group_size in [-1, 128]:
            # Quantize weights
            q_w, s_group, s_channel = choose_qparams_and_quantize_affine_qqq(
                w, mapping_type, num_bits, group_size
            )

            # The Marlin QQQ packing expects (in_features, out_features) operands.
            q_w = q_w.t()
            s_group = s_group.t()
            s_channel = s_channel.t()

            # Test pack/unpack equivalence
            q_w_comp, packed_s_group, packed_s_channel = pack_to_marlin_qqq(
                q_w, s_group, s_channel, num_bits, group_size
            )
            unpacked_q_w, unpacked_s_group, unpacked_s_channel = unpack_from_marlin_qqq(
                q_w_comp,
                packed_s_group,
                packed_s_channel,
                q_w.shape,
                num_bits,
                group_size,
            )

            assert torch.equal(
                q_w, unpacked_q_w
            ), "Unpacked weights do not match original weights"
            assert torch.equal(
                s_channel, unpacked_s_channel
            ), "Unpacked s_channel does not match original s_channel"
            assert torch.equal(
                s_group, unpacked_s_group
            ), "Unpacked s_group does not match original s_group"


if __name__ == "__main__":
    run_tests()
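
For reference, a hedged end-to-end sketch of the low-level flow the benchmark and tests above exercise together: quantize an fp16 weight with choose_qparams_and_quantize_affine_qqq, pack it with pack_to_marlin_qqq, then call marlin_qqq_gemm directly. Argument order and tensor shapes follow the code above; the sizes and tensor names are illustrative:

import torch
from torchao.ops import marlin_qqq_gemm
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_and_quantize_affine_qqq,
)

m, k, n, group_size = 16, 4096, 11008, 128
dev = torch.device("cuda")

# Quantize an (out_features, in_features) fp16 weight to int4 with
# per-group and per-channel scales.
w = torch.randn((n, k), dtype=torch.float16, device=dev)
q_w, s_group, s_channel = choose_qparams_and_quantize_affine_qqq(
    w, MappingType.SYMMETRIC, 4, group_size
)

# Pack into the Marlin QQQ layout; packing consumes (k, n)-shaped operands.
q_w, s_group, s_channel = pack_to_marlin_qqq(
    q_w.t(), s_group.t(), s_channel.t(), num_bits=4, group_size=group_size
)

# int8 activations with per-token scales, plus the kernel workspace.
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
s_tok = torch.ones((m, 1), dtype=torch.float, device=dev)
out = marlin_qqq_gemm(
    A, q_w, s_tok, s_channel, s_group, marlin_qqq_workspace(n), m, n, k
)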
