@@ -479,12 +479,19 @@ def from_float(cls, weight):
 
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
-    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
     """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
     @classmethod
     def from_float(cls, weight):
         block_size = (1, weight.shape[1])
-        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=torch.float8_e4m3fn, layout_type=Float8LayoutType())
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
+
 
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -500,7 +507,7 @@ def from_float(cls, weight):
 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G64WeightOnlyQuantizedLinearWeight
 ]
 
 def _change_linears_to_autoquantizable(model, **kwargs):
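
For readers skimming the diff, here is a minimal, self-contained sketch of what the new weight-only FP8 path boils down to at inference time. It is illustrative only: fp8_weight_only_linear and the per-row max-abs scaling against the float8_e4m3fn limit are assumptions made for this example, not torchao's from_hp_to_floatx internals. The one detail taken directly from the diff is that _quantized_linear_op simply dequantizes the weight and calls torch.nn.functional.linear, so the matmul itself still runs in the activation's precision.

import torch

# Hypothetical helper (not part of torchao): mimics the behavior the diff encodes.
# block_size = (1, weight.shape[1]) corresponds to one scale per output row of the weight.
def fp8_weight_only_linear(act, weight, bias=None):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max                        # ~448 for e4m3fn
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / fp8_max
    w_fp8 = (weight / scale).to(torch.float8_e4m3fn)                      # weight-only quantization
    w_deq = w_fp8.to(weight.dtype) * scale                                # dequantize, as _quantized_linear_op does
    return torch.nn.functional.linear(act, w_deq, bias)                   # plain high-precision matmul

x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randn(32, 64, dtype=torch.bfloat16)
print(fp8_weight_only_linear(x, w).shape)  # torch.Size([4, 32])

Falling back to dequantize-then-linear keeps the op correct on any hardware; whether autoquant ultimately selects this class for a given layer then comes down to how it benchmarks against the other candidates in the class list.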