Commit c11de33

[Bugfix][Kernel] Fix per-token/per-channel quantization for Hopper scaled mm (#12696)
Signed-off-by: Tyler Michael Smith <[email protected]>
1 parent 33e0602 commit c11de33

File tree

1 file changed: +24 -35 lines changed

csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu

Lines changed: 24 additions & 35 deletions
@@ -16,55 +16,44 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
-  using GroupShape = std::array<int64_t, 2>;
-
   int M = a.size(0), N = b.size(1), K = a.size(1);
 
-  GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
-    if (s.numel() == 1) return {M, K};  // tensor-wise
-    if (s.dim() == 2)
-      return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
-    TORCH_CHECK(false, "Unsupported scale shape for scale_a");
-  }();
-
-  GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
-    if (s.numel() == 1) return {K, N};  // tensor-wise
-    if (s.dim() == 2)
-      return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
-    TORCH_CHECK(false, "Unsupported scale shape for scale_b");
-  }();
-
-  if ((a_scale_group_shape == GroupShape{M, K} ||
-       a_scale_group_shape == GroupShape{1, K}) &&
-      (b_scale_group_shape == GroupShape{K, N} ||
-       b_scale_group_shape == GroupShape{K, 1})) {
-    // "standard per-tensor/per-token/per-channel" scaling
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+    // Standard per-tensor/per-token/per-channel scaling
     TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
     if (a.dtype() == torch::kFloat8_e4m3fn) {
       vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
     } else {
       TORCH_CHECK(a.dtype() == torch::kInt8);
       vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
     }
-  } else if (a_scale_group_shape == GroupShape{1, 128} &&
-             b_scale_group_shape == GroupShape{128, 128}) {
+  } else {
+    using GroupShape = std::array<int64_t, 2>;
+    auto make_group_shape = [](torch::Tensor const& x,
+                               torch::Tensor const& s) -> GroupShape {
+      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+      return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
+    };
+
+    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
     // 1x128 per-token group scales for activations
     // 128x128 blockwise scales for weights
-    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
-                    b.dtype() == torch::kFloat8_e4m3fn,
-                "Currently only FP8 is supported for A group shape 1x128 and "
-                "B group shape 128x128");
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
-    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
-  } else {
-    TORCH_CHECK(false,
-                "Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
-                "a_scale_group_shape must be [1, 128], got: [",
+    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                 b_scale_group_shape == GroupShape{128, 128} &&
+                 a.dtype() == torch::kFloat8_e4m3fn &&
+                 b.dtype() == torch::kFloat8_e4m3fn),
+                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                "a_scale_group_shape must be [1, 128]. Got: [",
                 a_scale_group_shape[0], ", ", a_scale_group_shape[1],
                 "]\n"
-                "b_scale_group_shape must be [128, 128], got: [",
+                "b_scale_group_shape must be [128, 128]. Got: [",
                 b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+
+    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
   }
 }
 
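What the fix does (my reading of the diff; the commit message does not spell it out): the old code classified scales by computing a GroupShape up front and only handled scale tensors with numel() == 1 or dim() == 2, so a 1-D per-token scale of length M (or a 1-D per-channel scale of length N) fell through to TORCH_CHECK(false, "Unsupported scale shape ..."). The new code selects the standard per-tensor/per-token/per-channel kernels from numel() alone and defers the GroupShape computation to the blockwise branch, which requires 2-D scales explicitly. The standalone sketch below contrasts the two conditions for activation scales; ScaleShape, old_accepts, new_accepts, and the sample sizes are hypothetical names for illustration, not vLLM code.

// dispatch_sketch.cc -- hypothetical illustration, not vLLM source.
// Build: g++ -std=c++17 dispatch_sketch.cc
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// A scale tensor reduced to the two properties the dispatch inspects.
struct ScaleShape {
  std::vector<int64_t> dims;
  int64_t dim() const { return static_cast<int64_t>(dims.size()); }
  int64_t numel() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
  }
};

// Pre-fix condition, simplified: the real code mapped 2-D scales to a
// GroupShape and compared against {M, K}/{1, K}, but anything that was
// neither numel() == 1 nor 2-D hit
// TORCH_CHECK(false, "Unsupported scale shape for scale_a").
bool old_accepts(int64_t /*M*/, const ScaleShape& s) {
  return s.numel() == 1 || s.dim() == 2;
}

// Post-fix condition from this commit: only numel() matters, so both
// [M] and [M, 1] per-token scales reach the standard kernel.
bool new_accepts(int64_t M, const ScaleShape& s) {
  return s.numel() == 1 || s.numel() == M;
}

int main() {
  const int64_t M = 256;            // rows (tokens) of activation a
  ScaleShape per_token_1d{{M}};     // 1-D per-token scales
  ScaleShape per_token_2d{{M, 1}};  // 2-D per-token scales

  assert(!old_accepts(M, per_token_1d));  // previously: hard error
  assert(new_accepts(M, per_token_1d));   // now: standard scaled-mm path
  assert(old_accepts(M, per_token_2d) && new_accepts(M, per_token_2d));

  std::printf("1-D per-token scales now take the standard path\n");
  return 0;
}

The same reasoning applies symmetrically to b_scales, with per-channel weight scales compared against N = b.size(1).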
