3030from torchao .utils import (
3131 TORCH_VERSION_AT_LEAST_2_3 ,
3232 TORCH_VERSION_AT_LEAST_2_5 ,
33- benchmark_model ,
33+ TorchAOBaseTensor ,
3434)
3535
3636from torchao .quantization .granularity import (
6161 "autoquant_v2" ,
6262 "DEFAULT_AUTOQUANT_CLASS_LIST" ,
6363 "DEFAULT_INT4_AUTOQUANT_CLASS_LIST" ,
64+ "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST" ,
6465 "OTHER_AUTOQUANT_CLASS_LIST" ,
6566 "_is_linear" ,
6667]
@@ -288,7 +289,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
288289 )
289290 elif (self .logged_data == {}) and not error_on_unseen :
290291 # default back to non-quantized weight if not seen
291- self = AQFloatLinearWeight .from_float (self .weight )
292+ self = AQDefaultLinearWeight .from_float (self .weight )
292293 return self
293294
294295 # only want to print shape (at start) and final result (at end)
@@ -360,7 +361,7 @@ def count_shapes(self, do_print=True):
360361 print (f"best_cls={ best_cls } \n " )
361362 # TODO handle random cls args/kwargs? or should they be curried?
362363 if best_cls is None :
363- best_cls = AQFloatLinearWeight
364+ best_cls = AQDefaultLinearWeight
364365
365366 self = best_cls .from_float (self .weight )
366367 return self
@@ -802,7 +803,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
802803 group_size : int = 256
803804
804805
805- class AQFloatLinearWeight (torch .Tensor , AQMixin ):
806+ class AQDefaultLinearWeight (torch .Tensor , AQMixin ):
806807 """
807808 A class to be used in concert with AutoQuantizableLinearWeight to provide a
808809 default/non-quantized option. Only implements the bare minimum needed to work with the
@@ -823,6 +824,130 @@ def from_float(cls, weight):
823824 return weight
824825
825826
class Float32Tensor(TorchAOBaseTensor):
    """Tensor subclass that keeps its linear weight stored in float32.

    Used as the base "plain precision" option for autoquant: the activation
    and bias are cast to float32 for the matmul and the result is cast back
    to the activation's original dtype.
    """

    def __init__(self, weight):
        # Store the weight in float32 regardless of the incoming dtype.
        self.weight = weight.to(torch.float32)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        # Run the linear op entirely in float32, then restore the caller's dtype.
        orig_dtype = act_mat.dtype
        if bias is not None:
            bias = bias.to(torch.float32)
        result = torch.nn.functional.linear(
            act_mat.to(torch.float32),
            w_qtensor.weight,
            bias,
        )
        return result.to(dtype=orig_dtype)

    def _apply_fn_to_data(self, fn):
        # Rebuild an instance of the same subclass with fn applied to the weight.
        return self.__class__(fn(self.weight))

    @classmethod
    def from_float(cls, weight):
        """Wrap a plain weight tensor in this subclass."""
        return cls(weight)

@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    # Dispatch linear to the subclass' precision-converting implementation.
    # bias is optional (positional arg 2 may be absent).
    bias = args[2] if len(args) > 2 else None
    weight_tensor = args[1]
    return weight_tensor._quantized_linear_op(args[0], weight_tensor, bias)
860+
@Float32Tensor.implements(aten.detach.default)
def _(func, types, args, kwargs):
    # detach: apply to the wrapped weight and fix up aliasing metadata.
    detached = args[0]._apply_fn_to_data(torch.detach)
    return return_and_correct_aliasing(func, args, kwargs, detached)
866+
867+
@Float32Tensor.implements(aten.clone.default)
def _(func, types, args, kwargs):
    # clone: apply to the wrapped weight and fix up aliasing metadata.
    cloned = args[0]._apply_fn_to_data(torch.clone)
    return return_and_correct_aliasing(func, args, kwargs, cloned)
873+
874+
@Float32Tensor.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    # _to_copy: forward the dtype/device conversion to the wrapped weight,
    # then clone so the result does not alias the source.
    converted = args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone)
    return return_and_correct_aliasing(func, args, kwargs, converted)
883+
884+
class BFloat16Tensor(Float32Tensor):
    """Variant of Float32Tensor that stores the weight in bfloat16."""

    def __init__(self, weight):
        # Store the weight in bfloat16 regardless of the incoming dtype.
        self.weight = weight.to(torch.bfloat16)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        # Compute the linear op in bfloat16, then restore the caller's dtype.
        orig_dtype = act_mat.dtype
        if bias is not None:
            bias = bias.to(torch.bfloat16)
        result = torch.nn.functional.linear(
            act_mat.to(torch.bfloat16),
            w_qtensor.weight,
            bias,
        )
        return result.to(dtype=orig_dtype)

898+
899+
class Float16Tensor(Float32Tensor):
    """Variant of Float32Tensor that stores the weight in float16."""

    def __init__(self, weight):
        # Store the weight in float16 regardless of the incoming dtype.
        self.weight = weight.to(torch.float16)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        # Compute the linear op in float16, then restore the caller's dtype.
        orig_dtype = act_mat.dtype
        if bias is not None:
            bias = bias.to(torch.float16)
        result = torch.nn.functional.linear(
            act_mat.to(torch.float16),
            w_qtensor.weight,
            bias,
        )
        return result.to(dtype=orig_dtype)

914+
class AQFloat32LinearWeight(Float32Tensor, AQMixin):
    """
    AutoQuantizable version for float32 precision weight

    (also converts input activation and bias to float32, and restores the
    original precision after linear)
    """

    @classmethod
    def from_float(cls, weight):
        # Delegate to Float32Tensor.from_float; zero-arg super() resolves
        # the same way as the explicit two-argument form.
        return super().from_float(weight)

926+
class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin):
    """
    AutoQuantizable version for bfloat16 precision weight

    (also converts input activation and bias to bfloat16, and restores the
    original precision after linear)
    """

    @classmethod
    def from_float(cls, weight):
        # Delegate to BFloat16Tensor.from_float; zero-arg super() resolves
        # the same way as the explicit two-argument form.
        return super().from_float(weight)

938+
class AQFloat16LinearWeight(Float16Tensor, AQMixin):
    """
    AutoQuantizable version for float16 precision weight

    (also converts input activation and bias to float16, and restores the
    original precision after linear)
    """

    @classmethod
    def from_float(cls, weight):
        # Delegate to Float16Tensor.from_float; zero-arg super() resolves
        # the same way as the explicit two-argument form.
        return super().from_float(weight)

949+
950+
826951class AQFloat8WeightOnlyQuantizedLinearWeight (AffineQuantizedTensor , AQMixin ):
827952 """
828953 AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
@@ -936,7 +1061,7 @@ def get_weight_block_size(x):
9361061
9371062# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
9381063DEFAULT_AUTOQUANT_CLASS_LIST = [
939- AQFloatLinearWeight ,
1064+ AQDefaultLinearWeight ,
9401065 AQInt8WeightOnlyQuantizedLinearWeight ,
9411066 AQInt8WeightOnlyQuantizedLinearWeight2 ,
9421067 # AQInt8WeightOnlyQuantizedLinearWeight3,
@@ -945,11 +1070,17 @@ def get_weight_block_size(x):
9451070]
9461071
# int4-oriented autoquant search space: the unquantized baseline plus
# int8 dynamic quantization and int4 groupwise (G64) weight-only quantization.
DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
    AQDefaultLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
    AQInt4G64WeightOnlyQuantizedLinearWeight,
]
9521077
# float-precision-only autoquant search space: no integer/float8 options,
# just the weight stored in fp32 / bf16 / fp16 so autoquant can pick the
# fastest plain-precision variant.
DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
    AQFloat32LinearWeight,
    AQBFloat16LinearWeight,
    AQFloat16LinearWeight,
]
1083+
9531084OTHER_AUTOQUANT_CLASS_LIST = [
9541085 AQFloat8WeightOnlyQuantizedLinearWeight ,
9551086 AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight ,
0 commit comments