@@ -9,20 +9,16 @@
 import unittest
 
 import torch
-from parameterized import parameterized
 
-from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
+    choose_qparams_affine_tinygemm,
     dequantize_affine,
-    dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
     quantize_affine,
-    quantize_affine_float8,
 )
 
 # TODO: remove test for utils?
@@ -650,35 +646,6 @@ def test_raises(self):
         with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
             _ = quantize_affine(input, block_size, scale, zero_point, dtype)
 
-    def test_not_preserve_zero_not_supported(self):
-        """Making sure preserve_zero == False is not supported for symmetric quant"""
-        input = torch.randn(10, 256)
-        n_bit = 4
-        mapping_type = MappingType.SYMMETRIC
-        dtype = torch.int8
-        block_size = (1, 128)
-        quant_min = 0
-        quant_max = 2**n_bit - 1
-        eps = 1e-6
-        scale_dtype = torch.bfloat16
-        zero_point_dtype = torch.bfloat16
-        with self.assertRaisesRegex(
-            ValueError,
-            "preserve_zero == False is not supported for symmetric quantization",
-        ):
-            choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=False,
-            )
-
     def test_get_groupwise_affine_qparams(self):
         input = torch.randn(10, 256)
         n_bit = 4
@@ -702,22 +669,33 @@ def test_get_groupwise_affine_qparams(self):
                 dtype=torch.bfloat16,
                 zero_point_domain=zero_point_domain,
             )
-            scale, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=zero_point_domain == ZeroPointDomain.INT,
-                zero_point_domain=zero_point_domain,
-            )
+            if zero_point_domain == ZeroPointDomain.FLOAT:
+                scale, zero_point = choose_qparams_affine_tinygemm(
+                    input,
+                    mapping_type,
+                    block_size,
+                    dtype,
+                    quant_min,
+                    quant_max,
+                    eps,
+                    scale_dtype=scale_dtype,
+                    zero_point_dtype=zero_point_dtype,
+                )
+            else:
+                scale, zero_point = choose_qparams_affine(
+                    input,
+                    mapping_type,
+                    block_size,
+                    dtype,
+                    quant_min,
+                    quant_max,
+                    eps,
+                    scale_dtype=scale_dtype,
+                    zero_point_dtype=zero_point_dtype,
+                )
 
-        self.assertTrue(torch.equal(scale, scale_ref))
-        self.assertTrue(torch.equal(zero_point, zero_point_ref))
+            self.assertTrue(torch.equal(scale, scale_ref))
+            self.assertTrue(torch.equal(zero_point, zero_point_ref))
 
     def test_groupwise_affine_quantize_tensor_from_qparams(self):
         input = torch.randn(10, 256)
@@ -847,120 +825,6 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
-    def test_none_zero_point_domain(self):
-        """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
-        input = torch.randn(10, 256)
-        mapping_type = MappingType.SYMMETRIC
-        dtype = torch.int8
-        block_size = (1, 128)
-        quant_min = None
-        quant_max = None
-        eps = 1e-6
-        scale_dtype = torch.float32
-        zero_point_dtype = torch.int64
-        try:
-            _, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=True,
-                zero_point_domain=None,
-            )
-        except ValueError:
-            # This exception was expected
-            # Now test for ZeroPointDomain.NONE
-            _, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=True,
-                zero_point_domain=ZeroPointDomain.NONE,
-            )
-            self.assertTrue(zero_point is None)
-        else:
-            # An exception should have been thrown for zero_point_domain None
-            self.assertTrue(
-                False,
-                msg="A runtime exception should have been thrown for zero_point_domain None",
-            )
-
-    @parameterized.expand(
-        [
-            (
-                torch.float32,
-                torch.float8_e4m3fn,
-            ),
-            (
-                torch.float32,
-                torch.float8_e5m2,
-            ),
-            (
-                torch.bfloat16,
-                torch.float8_e4m3fn,
-            ),
-            (
-                torch.bfloat16,
-                torch.float8_e5m2,
-            ),
-        ]
-    )
-    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
-        input = torch.randn(10, 10)
-
-        # float8 quantization primitives
-        scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
-        quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
-        dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)
-
-        # reference implementation using generic primitives
-        expected_scale, _ = choose_qparams_affine(
-            input,
-            MappingType.SYMMETRIC,
-            input.shape,
-            float8_dtype,
-            eps=float8_eps,  # use same EPS as float8 training
-            scale_dtype=torch.float32,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-        )
-        expected_quantized = quantize_affine(
-            input,
-            input.shape,
-            scale,
-            output_dtype=float8_dtype,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-            zero_point=None,
-            zero_point_domain=ZeroPointDomain.NONE,
-        )
-        expected_dequantized = dequantize_affine(
-            expected_quantized,
-            input.shape,
-            scale,
-            input_dtype=float8_dtype,
-            output_dtype=hp_dtype,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-            zero_point=None,
-            zero_point_domain=ZeroPointDomain.NONE,
-        )
-
-        self.assertTrue(torch.equal(expected_scale, scale))
-        torch.testing.assert_close(expected_quantized, quantized)
-        torch.testing.assert_close(expected_dequantized, dequantized)
-
 
 if __name__ == "__main__":
     unittest.main()
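
Taken together, these hunks replace the `zero_point_domain=ZeroPointDomain.FLOAT` / `preserve_zero=False` path through `choose_qparams_affine` with the dedicated `choose_qparams_affine_tinygemm` helper. Below is a minimal sketch of the new call pattern for callers making the same migration, assuming only the positional signature shown in the diff; the concrete shapes, dtypes, and `MappingType.ASYMMETRIC` value are illustrative, not taken from this commit.

```python
import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine_tinygemm,
)

# Illustrative inputs: group-wise (1, 128) blocks over a (10, 256) tensor,
# a 4-bit integer range, and bfloat16 scale/zero_point, echoing the test above.
x = torch.randn(10, 256)
scale, zero_point = choose_qparams_affine_tinygemm(
    x,
    MappingType.ASYMMETRIC,  # assumed mapping; the test passes its own `mapping_type`
    (1, 128),                # block_size
    torch.int8,              # target dtype
    0,                       # quant_min
    2**4 - 1,                # quant_max
    1e-6,                    # eps
    scale_dtype=torch.bfloat16,
    zero_point_dtype=torch.bfloat16,
)
```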