@@ -20,12 +20,17 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
+    int4wo,
+    int8wo,
+    int8da_int8w,
+    quantize,
+    _replace_with_custom_fn_if_matches_filter,
+)
+# APIs to be deprecated (used for torch 2.2.2 and 2.3)
+from torchao.quantization.quant_api import (
     change_linear_weights_to_int8_dqtensors,
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
-    _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
     safe_int_mm,
@@ -73,26 +78,53 @@
 from parameterized import parameterized
 import itertools
 import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+    is_fbcode,
+)
 
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
 config.cache_size_limit = 100
 
-# TODO: use this to reduce the number of tests
-TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
-]
-
 COMMON_DEVICES = ["cpu", "cuda"]
 
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
 
+def _int8wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod)
+
+def _int8da_int8w_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8da_int8w())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod)
+
+def _int4wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int4wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod)
+
+# TODO: use this to reduce the number of tests
+TENSOR_SUBCLASS_APIS = [
+    _int8wo_api,
+    _int8da_int8w_api,
+    _int4wo_api,
+]
+
+
 def combine_parameters(a, b):
     new_tuples = []
     for (tuple1, tuple2) in itertools.product(a, b):
@@ -756,14 +788,14 @@ def _test_lin_weight_subclass_api_impl(
     @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            _int8da_int8w_api, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            _int8wo_api, device, 40, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -773,7 +805,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device == 'cuda' else [])):
             self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                _int4wo_api,
                 device,
                 15,
                 test_shape=test_shape,
@@ -789,8 +821,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
                     kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+
+                    def api(mod):
+                        if TORCH_VERSION_AFTER_2_4:
+                            quantize(mod, int4wo(**kwargs))
+                            unwrap_tensor_subclass(mod)
+                        else:
+                            change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
                     self._test_lin_weight_subclass_api_impl(
-                        lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                        api,
                         device,
                         15,
                         test_shape=test_shape,
@@ -805,7 +845,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, int8da_int8w())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -819,7 +859,7 @@ def test_weight_only_quant(self):
             x = torch.randn(*x_shape)
             m = nn.Sequential(nn.Linear(4, 5))
             y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            _int8wo_api(m)
             y_wo = m(x)
             sqnr = compute_error(y_ref, y_wo)
             self.assertGreater(sqnr, 44.0)
@@ -842,7 +882,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
                 x = torch.randn(*x_shape).to(device).to(dtype)
                 m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
                 y_ref = m(x)
-                apply_weight_only_int8_quant(m)
+                _int8wo_api(m)
                 m(x)
                 m_c = torch.compile(m, mode="max-autotune")
                 y_wo, (code,) = run_and_get_code(m_c, x)
@@ -869,7 +909,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
                 x = torch.randn(*x_shape).to(device).to(dtype)
                 m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
                 y_ref = m(x)
-                apply_weight_only_int8_quant(m)
+                _int8wo_api(m)
                 m_c = torch.compile(m, mode="max-autotune")
                 y_wo, (code,) = run_and_get_code(m_c, x)
                 sqnr = compute_error(y_ref, y_wo)
@@ -910,6 +950,7 @@ def forward(self, x):
 
         # save quantized state_dict
         api(model)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -925,6 +966,7 @@ def forward(self, x):
         # load quantized state_dict
         state_dict = torch.load("test.pth", mmap=True)
         os.remove("test.pth")
+
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
@@ -941,21 +983,21 @@ def forward(self, x):
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"inductor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)
 
 
 class TorchCompileUnitTest(unittest.TestCase):
@@ -1275,8 +1317,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)
 
         # running model
         model(x)
@@ -1321,8 +1362,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)
 
         # running model
         ref = model(x)
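The version gate used by the new _int8wo_api / _int8da_int8w_api / _int4wo_api helpers can be applied outside this test file as well. Below is a minimal sketch, assuming the torchao revision from this diff (later torchao releases rename some of these entry points), using only names imported above: quantize, int8wo, unwrap_tensor_subclass, and the legacy change_linear_weights_to_int8_woqtensors.

import torch.nn as nn

from torchao.quantization.quant_api import (
    quantize,
    int8wo,
    change_linear_weights_to_int8_woqtensors,
)
from torchao.utils import TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(4, 5))

if TORCH_VERSION_AFTER_2_4:
    # torch >= 2.4: quantize() swaps the Linear weights for int8 weight-only
    # tensor subclasses; unwrap_tensor_subclass() then puts the module in a
    # form that torch.compile / export can handle without subclass support.
    quantize(model, int8wo())
    unwrap_tensor_subclass(model)
else:
    # torch 2.2.2 / 2.3: fall back to the older module-rewriting API,
    # which this diff marks for deprecation.
    change_linear_weights_to_int8_woqtensors(model)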