 import functools
 import json
 import os
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 
@@ -27,6 +27,76 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
 
 
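+# Thin wrapper giving the CUTLASS kernel the same (A, B, As, Bs, block_size,
+# output_dtype) signature as the other block-scale backends; B and its scales
+# are transposed into the layout ops.cutlass_scaled_mm expects.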
+def cutlass_scaled_mm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    return ops.cutlass_scaled_mm(A,
+                                 B.T,
+                                 out_dtype=output_dtype,
+                                 scale_a=As,
+                                 scale_b=Bs.T)
+
+
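+# ROCm AITER block-scale GEMM. The aiter import is deferred into the function
+# body so this module still imports on builds without the package.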
+def rocm_aiter_gemm_w8a8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    import aiter as rocm_aiter
+
+    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+
+
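+# Fake (meta) implementation paired with the kernel above: it only allocates
+# an output of the right shape so torch.compile can trace through the op
+# without executing the GEMM.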
+def rocm_aiter_gemm_w8a8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
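+# Expose the AITER kernel as a vLLM custom op (torch.ops.vllm.*) so it can be
+# dispatched and traced like any other registered op; registration is skipped
+# off ROCm.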
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8_blockscale",
+        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
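+# Backend selection: CUTLASS when requested and shape-supported, otherwise the
+# AITER custom op on supported ROCm builds, otherwise the Triton
+# w8a8_block_fp8_matmul fallback.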
+def dispatch_w8a8_blockscale_func(
+    use_cutlass: bool, use_aiter_and_is_supported: bool
+) -> Callable[[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        list[int],
+        torch.dtype,
+], torch.Tensor]:
+    if use_cutlass:
+        return cutlass_scaled_mm
+    if use_aiter_and_is_supported:
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
+    return w8a8_block_fp8_matmul
+
+
 # TODO fix ROCm->Triton custom path:
 #  https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -37,26 +107,23 @@ def apply_w8a8_block_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
-                                  and weight.shape[1] % 128 == 0)
-    if current_platform.is_rocm():
-        # TODO this is never used, as cutlass_block_fp8_supported is False
-        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
-                         input_2d.shape[:-1])[::-1]
-        scale_b_shape = (weight_scale.view(-1, 1)
-                         if weight_scale.dim() <= 1 else weight_scale.T).shape
-        ar, ac = scale_a_shape
-        br, bc = scale_b_shape
-        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
-                or br not in (1, weight.shape[0])):
-            shape_supported_by_cutlass = False
-    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
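+    # CUTLASS is only considered on CUDA, and only when both weight dims are
+    # multiples of the 128x128 scale block; ROCm takes the AITER or Triton
+    # path instead.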
+    if current_platform.is_cuda():
+        use_cutlass = cutlass_block_fp8_supported and (
+            weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+    else:
+        use_cutlass = False
+
+    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported)
+
+    if use_cutlass:
         rows, cols = input_2d.shape
         # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
         # optimal tensor core usage. Can be removed when targeting platforms
@@ -67,26 +134,22 @@
             input_2d = torch.nn.functional.pad(input_2d,
                                                (0, 0, 0, 4 - (rows % 4)),
                                                value=0).contiguous()
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=True)
-        output = ops.cutlass_scaled_mm(q_input,
-                                       weight.T,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale.T)
+
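+        # column_major_scales tracks the chosen backend: CUTLASS consumes
+        # column-major activation scales, the Triton/AITER kernels row-major.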
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
         if should_pad:
             output = output[:rows, :]
+
     else:
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=False)
-        output = w8a8_block_fp8_matmul(q_input,
-                                       weight,
-                                       x_scale,
-                                       weight_scale,
-                                       block_size,
-                                       output_dtype=input.dtype)
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
+
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
@@ -98,6 +161,9 @@ def apply_w8a8_block_fp8_linear_fake(
     block_size: list[int],
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     output_shape = [*input.shape[:-1], weight.shape[0]]
     return torch.empty(output_shape, dtype=input.dtype, device=input.device)