@@ -16,55 +16,44 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
-  using GroupShape = std::array<int64_t, 2>;
-
   int M = a.size(0), N = b.size(1), K = a.size(1);
 
-  GroupShape a_scale_group_shape = [&, &s = a_scales]() -> GroupShape {
-    if (s.numel() == 1) return {M, K};  // tensor-wise
-    if (s.dim() == 2)
-      return {ceil_div(a.size(0), s.size(0)), ceil_div(a.size(1), s.size(1))};
-    TORCH_CHECK(false, "Unsupported scale shape for scale_a");
-  }();
-
-  GroupShape b_scale_group_shape = [&, &s = b_scales]() -> GroupShape {
-    if (s.numel() == 1) return {K, N};  // tensor-wise
-    if (s.dim() == 2)
-      return {ceil_div(b.size(0), s.size(0)), ceil_div(b.size(1), s.size(1))};
-    TORCH_CHECK(false, "Unsupported scale shape for scale_b");
-  }();
-
-  if ((a_scale_group_shape == GroupShape{M, K} ||
-       a_scale_group_shape == GroupShape{1, K}) &&
-      (b_scale_group_shape == GroupShape{K, N} ||
-       b_scale_group_shape == GroupShape{K, 1})) {
-    // "standard per-tensor/per-token/per-channel" scaling
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+    // Standard per-tensor/per-token/per-channel scaling
     TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
     if (a.dtype() == torch::kFloat8_e4m3fn) {
       vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
     } else {
       TORCH_CHECK(a.dtype() == torch::kInt8);
       vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
     }
-  } else if (a_scale_group_shape == GroupShape{1, 128} &&
-             b_scale_group_shape == GroupShape{128, 128}) {
+  } else {
+    using GroupShape = std::array<int64_t, 2>;
+    auto make_group_shape = [](torch::Tensor const& x,
+                               torch::Tensor const& s) -> GroupShape {
+      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+      return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
+    };
+
+    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
     // 1x128 per-token group scales for activations
     // 128x128 blockwise scales for weights
-    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn &&
-                    b.dtype() == torch::kFloat8_e4m3fn,
-                "Currently only FP8 is supported for A group shape 1x128 and "
-                "B group shape 128x128");
-    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
-
-    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
-  } else {
-    TORCH_CHECK(false,
-                "Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
-                "a_scale_group_shape must be [1, 128], got: [",
+    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                 b_scale_group_shape == GroupShape{128, 128} &&
+                 a.dtype() == torch::kFloat8_e4m3fn &&
+                 b.dtype() == torch::kFloat8_e4m3fn),
+                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                "a_scale_group_shape must be [1, 128]. Got: [",
                 a_scale_group_shape[0], ", ", a_scale_group_shape[1],
                 "]\n"
-                "b_scale_group_shape must be [128, 128], got: [",
+                "b_scale_group_shape must be [128, 128]. Got: [",
                 b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
+
+    vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
   }
 }
 
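For context on what the new dispatch computes: the sketch below is a standalone illustration, not vLLM code. Plain integers stand in for the torch::Tensor sizes, ceil_div is reimplemented locally, and all names are local to the example. It shows how the make_group_shape logic turns tensor and scale-tensor sizes into the {1, 128} / {128, 128} group shapes the blockwise path requires.

// Standalone sketch (not part of the vLLM build): all names below are
// local to this example.
#include <array>
#include <cassert>
#include <cstdint>

using GroupShape = std::array<std::int64_t, 2>;

// Round-up integer division, as assumed for the ceil_div helper.
constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

// Mirrors the make_group_shape lambda: one scale entry covers a group of
// ceil(rows / s_rows) x ceil(cols / s_cols) elements of the quantized tensor.
GroupShape make_group_shape(std::int64_t rows, std::int64_t cols,
                            std::int64_t s_rows, std::int64_t s_cols) {
  return {ceil_div(rows, s_rows), ceil_div(cols, s_cols)};
}

int main() {
  // Hypothetical sizes: A is M x K activations, B is K x N weights.
  std::int64_t M = 256, N = 512, K = 4096;

  // Activations: one scale per 1x128 group -> scale tensor is M x (K / 128).
  GroupShape a_shape = make_group_shape(M, K, M, K / 128);
  assert((a_shape == GroupShape{1, 128}));

  // Weights: one scale per 128x128 block -> scale tensor is
  // (K / 128) x (N / 128).
  GroupShape b_shape = make_group_shape(K, N, K / 128, N / 128);
  assert((b_shape == GroupShape{128, 128}));
  return 0;
}

Any other 2D scale layout produces a group shape that fails the TORCH_CHECK above, which is how the rewritten function rejects unsupported configurations.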