@@ -20,11 +20,10 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
+    quantize,
     _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
@@ -73,7 +72,11 @@
 from parameterized import parameterized
 import itertools
 import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+)
 
 logger = logging.getLogger("INFO")
 
@@ -82,9 +85,9 @@
 
 # TODO: use this to reduce the number of tests
 TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
 ]
 
 COMMON_DEVICES = ["cpu", "cuda"]
@@ -736,7 +739,8 @@ def _test_lin_weight_subclass_api_impl(
             nn.Linear(k, n, device=test_device), nn.ReLU(), nn.Linear(n, n, device=test_device)
         ).to(test_dtype)
         ref_f = mod(x)
-        api(mod)
+        quantize(mod, api())
+        unwrap_tensor_subclass(mod)
 
         test = mod(x)
         self.assertGreater(
@@ -756,13 +760,13 @@ def _test_lin_weight_subclass_api_impl(
     @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            get_apply_int8dyn_quant, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            get_apply_int8wo_quant, device, 40, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -772,7 +776,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device == 'cuda' else [])):
             self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                get_apply_int4wo_quant,
                 device,
                 15,
                 test_shape=test_shape,
@@ -789,7 +793,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             for inner_k_tiles in [4, 2]:
                 kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
                 self._test_lin_weight_subclass_api_impl(
-                    lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                    lambda: get_apply_int4wo_quant(**kwargs),
                     device,
                     15,
                     test_shape=test_shape,
@@ -804,7 +808,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, get_apply_int8dyn_quant())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -818,7 +822,7 @@ def test_weight_only_quant(self):
             x = torch.randn(*x_shape)
             m = nn.Sequential(nn.Linear(4, 5))
             y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            quantize(m, get_apply_int8wo_quant())
             y_wo = m(x)
             sqnr = compute_error(y_ref, y_wo)
             self.assertGreater(sqnr, 44.0)
@@ -841,7 +845,8 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
             x = torch.randn(*x_shape).to(device).to(dtype)
             m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
             y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            m = quantize(m, get_apply_int8wo_quant())
+            m = unwrap_tensor_subclass(m)
             m(x)
             m_c = torch.compile(m, mode="max-autotune")
             y_wo, (code,) = run_and_get_code(m_c, x)
@@ -868,7 +873,8 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
             x = torch.randn(*x_shape).to(device).to(dtype)
             m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
             y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            m = quantize(m, get_apply_int8wo_quant())
+            m = unwrap_tensor_subclass(m)
             m_c = torch.compile(m, mode="max-autotune")
             y_wo, (code,) = run_and_get_code(m_c, x)
             sqnr = compute_error(y_ref, y_wo)
@@ -908,7 +914,9 @@ def forward(self, x):
         ref_f = model(x)
 
         # save quantized state_dict
-        api(model)
+        quantize(model, api())
+        unwrap_tensor_subclass(model)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -919,11 +927,13 @@ def forward(self, x):
         # load model structure
         with torch.device('meta'):
             model = test_model().to(dtype=test_dtype)
-        api(model)
+        quantize(model, api())
+        unwrap_tensor_subclass(model)
 
         # load quantized state_dict
         state_dict = torch.load("test.pth", mmap=True)
         os.remove("test.pth")
+
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
@@ -939,20 +949,20 @@ def forward(self, x):
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"inductor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int8dyn_quant, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int8wo_quant, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(get_apply_int4wo_quant, device, 20, test_dtype=dtype)
 
 
 class TorchCompileUnitTest(unittest.TestCase):
@@ -1271,8 +1281,8 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        # kwargs = {"dtype": test_dtype}
+        quantize(model, api())
 
         # running model
         model(x)
@@ -1317,8 +1327,9 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        # kwargs = {"dtype": test_dtype}
+        model = quantize(model, api())
+        model = unwrap_tensor_subclass(model)
 
         # running model
         ref = model(x)
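
For reference, every hunk above applies the same migration: the removed one-shot helpers (apply_dynamic_quant, apply_weight_only_int8_quant, change_linear_weights_to_int8_dqtensors/int8_woqtensors/int4_woqtensors) are replaced by quantize() driven by a get_apply_*_quant factory, with unwrap_tensor_subclass() added wherever the module is subsequently compiled or its state_dict is saved. A minimal sketch of the new call sequence, assuming the torchao revision this diff targets (the factory and helper names are taken from the diff itself, not from a pinned stable API):

import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize, get_apply_int8wo_quant
from torchao.utils import unwrap_tensor_subclass

# toy float model; any stack of nn.Linear layers works the same way
m = nn.Sequential(nn.Linear(4, 5))
x = torch.randn(2, 4)
y_ref = m(x)

# old API (removed above): apply_weight_only_int8_quant(m)
# new API: quantize() takes the module plus an apply-function built by a factory
m = quantize(m, get_apply_int8wo_quant())

# per the tests above, needed before torch.compile or state_dict save/load:
# converts tensor-subclass weights so they round-trip as ordinary parameters
m = unwrap_tensor_subclass(m)

y_q = m(x)  # int8 weight-only quantized inference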