2 files changed: +29 -3 lines
File tree: model_executor/layers/quantization/utils

@@ -559,15 +559,15 @@ def cutlass_scaled_mm(a: torch.Tensor,
         scale_a.shape * [1, 128] == a.shape
         scale_b.shape * [128, 128] == b.shape
     """
-    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.shape[0] == b.shape[
         1] and bias.dtype == out_dtype
 
     m = a.shape[0]
     n = b.shape[1]
 
-    if current_platform.is_rocm():
+    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    if current_platform.is_rocm() or not cutlass_compatible_b:
         triton_scaled_mm_module = importlib.import_module(
             "vllm.model_executor.layers.quantization.compressed_tensors."
             "triton_scaled_mm")
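
The net effect of this hunk: the hard assertion that both dimensions of `b` be multiples of 16 is removed, and such shapes are now routed to the Triton scaled-mm fallback, the same path already taken on ROCm. A minimal sketch of the resulting dispatch decision (the helper name `_should_use_triton_fallback` is hypothetical, used here only for illustration, not part of the diff):

```python
import torch

# Hypothetical helper illustrating the dispatch logic after this change.
def _should_use_triton_fallback(b: torch.Tensor, is_rocm: bool) -> bool:
    # CUTLASS scaled-mm kernels require both dims of `b` to be 16-aligned.
    cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    # ROCm never has the CUTLASS path; a misaligned `b` now also falls back
    # to the Triton kernel instead of tripping the removed assertion.
    return is_rocm or not cutlass_compatible_b
```

For example, a weight of shape (4096, 11008) stays on the CUTLASS path, while one of shape (4096, 11000) would previously have failed the assertion and now falls back to Triton.
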
@@ -85,6 +85,32 @@ def block_dequant(
     return x_dq_block
 
 
+if current_platform.is_rocm():
+    from triton.language import core
+
+    # NOTE: This can be removed when hip.libdevice.round() is available.
+    @core.extern
+    def round_f32(arg0, _builder=None):
+        return core.extern_elementwise("",
+                                       "", [arg0], {
+                                           (core.dtype("fp32"), ):
+                                           ("llvm.round", core.dtype("fp32")),
+                                           (core.dtype("fp64"), ):
+                                           ("llvm.round", core.dtype("fp64")),
+                                       },
+                                       is_pure=True,
+                                       _builder=_builder)
+
+    @triton.jit
+    def round_int8(x):
+        return round_f32(x).to(tl.int8)
+else:
+
+    @triton.jit
+    def round_int8(x):
+        return tl.extra.cuda.libdevice.round(x).to(tl.int8)
+
+
 @triton.jit
 def _per_token_quant_int8(
     x_ptr,
@@ -106,7 +132,7 @@ def _per_token_quant_int8(
     absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
     scale_x = absmax / 127
     x_q = x * (127 / absmax)
-    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+    x_q = round_int8(x_q)
 
     tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
     tl.store(scale_ptr + row_id, scale_x)
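
The new ROCm branch maps `round_f32` to the `llvm.round` intrinsic, which rounds ties away from zero, matching the CUDA `libdevice.round` used on the non-ROCm path; note this differs from `torch.round`, which rounds ties to even. A minimal host-side sketch of the rounding behavior the int8 quantization kernel relies on (the reference function below is illustrative only, not part of the diff):

```python
import torch

def round_half_away_from_zero_int8(x: torch.Tensor) -> torch.Tensor:
    # Emulates llvm.round / CUDA libdevice round(): ties round away from
    # zero (2.5 -> 3, -2.5 -> -3), unlike torch.round (ties to even).
    return (torch.sign(x) * torch.floor(torch.abs(x) + 0.5)).to(torch.int8)

x = torch.tensor([2.5, -2.5, 126.6, -127.4])
print(round_half_away_from_zero_int8(x))  # tensor([   3,   -3,  127, -127], dtype=torch.int8)
print(torch.round(x).to(torch.int8))      # tensor([   2,   -2,  127, -127], dtype=torch.int8)
```

Since per-token scaling maps the absolute maximum to exactly 127, the two rounding modes only diverge on exact .5 ties, so the fallback keeps the quantized outputs consistent across platforms for practical purposes.
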