From fd8f32fd9b724f981e944115ceacd5db56017f89 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 3 Feb 2025 15:47:45 +0000 Subject: [PATCH] Fix per-token/per-channel quantization for Hopper scaled mm Signed-off-by: Tyler Michael Smith --- .../cutlass_w8a8/scaled_mm_c3x.cu | 49 ++++++++----------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index e6f06d72fbfd..5cb4312be6d5 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -16,29 +16,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - using GroupShape = std::array; - int M = a.size(0), N = b.size(1), K = a.size(1); - GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape { - if (s.numel() == 1) return {M, K}; // tensor-wise - if (s.dim() == 2) - return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))}; - TORCH_CHECK(false, "Unsupported scale shape for scale_a"); - }(); - - GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape { - if (s.numel() == 1) return {K, N}; // tensor-wise - if (s.dim() == 2) - return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))}; - TORCH_CHECK(false, "Unsupported scale shape for scale_b"); - }(); - - if ((a_scale_group_shape == GroupShape{M, K} || - a_scale_group_shape == GroupShape{1, K}) && - (b_scale_group_shape == GroupShape{K, N} || - b_scale_group_shape == GroupShape{K, 1})) { - // "standard per-tensor/per-token/per-channel" scaling + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); if (a.dtype() == torch::kFloat8_e4m3fn) { vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias); @@ -46,19 +28,28 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, TORCH_CHECK(a.dtype() == torch::kInt8); vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias); } - } else if (a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128}) { + } else { + using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + // 1x128 per-token group scales for activations // 128x128 blockwise scales for weights - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn, - "Currently only FP8 is supported for A group shape 1x128 and " - "B group shape 128x128"); + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn and " + "group shapes 1x128 for A and 128x128 for B"); TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales); - } else { - TORCH_CHECK(false, "Unsupported scale group shapes for CUTLASS 3.x GEMM"); } }