Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 24 additions & 35 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,55 +16,44 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

using GroupShape = std::array<int64_t, 2>;

int M = a.size(0), N = b.size(1), K = a.size(1);

GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
if (s.numel() == 1) return {M, K}; // tensor-wise
if (s.dim() == 2)
return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
TORCH_CHECK(false, "Unsupported scale shape for scale_a");
}();

GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
if (s.numel() == 1) return {K, N}; // tensor-wise
if (s.dim() == 2)
return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
TORCH_CHECK(false, "Unsupported scale shape for scale_b");
}();

if ((a_scale_group_shape == GroupShape{M, K} ||
a_scale_group_shape == GroupShape{1, K}) &&
(b_scale_group_shape == GroupShape{K, N} ||
b_scale_group_shape == GroupShape{K, 1})) {
// "standard per-tensor/per-token/per-channel" scaling
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
Comment on lines -37 to +22
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change

// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
}
} else if (a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128}) {
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
};

GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);

// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn,
"Currently only FP8 is supported for A group shape 1x128 and "
"B group shape 128x128");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");

vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(false,
"Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
"a_scale_group_shape must be [1, 128], got: [",
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128], got: [",
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");

vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
}
}

Expand Down
Loading