@@ -145,13 +145,15 @@ def _int8da_int8w_api(
145145 change_linear_weights_to_int8_dqtensors (mod )
146146
147147
148- def _int4wo_api (mod ):
148+ def _int4wo_api (mod , use_hqq = False ):
149149 if (
150150 is_device (next (mod .parameters ()).device .type , "cpu" )
151151 and TORCH_VERSION_AT_LEAST_2_6
152152 ):
153153 quantize_ (
154- mod , int4_weight_only (layout = Int4CPULayout ()), set_inductor_config = False
154+ mod ,
155+ int4_weight_only (layout = Int4CPULayout (), use_hqq = use_hqq ),
156+ set_inductor_config = False ,
155157 )
156158 unwrap_tensor_subclass (mod )
157159 elif TORCH_VERSION_AT_LEAST_2_4 :
@@ -1049,8 +1051,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
10491051 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
10501052 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
10511053 def test_int4_weight_only_quant_subclass_api (self , device , dtype ):
1052- if device == "cpu" :
1053- self .skipTest (f"Temporarily skipping for { device } " )
10541054 if dtype != torch .bfloat16 :
10551055 self .skipTest (f"Fails for { dtype } " )
10561056 for test_shape in [(16 , 1024 , 16 )] + (
@@ -1060,6 +1060,20 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
10601060 _int4wo_api , device , 15 , test_shape = test_shape , test_dtype = dtype
10611061 )
10621062
1063+ @parameterized .expand (COMMON_DEVICE_DTYPE )
1064+ @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_6 , "int4 hqq requires torch nightly." )
1065+ def test_int4_weight_only_hqq_quant_subclass_api (self , device , dtype ):
1066+ if dtype != torch .bfloat16 :
1067+ self .skipTest (f"Fails for { dtype } " )
1068+ for test_shape in [(16 , 1024 , 16 ), (1 , 1024 , 256 )]:
1069+ api = partial (
1070+ _int4wo_api ,
1071+ use_hqq = True ,
1072+ )
1073+ self ._test_lin_weight_subclass_api_impl (
1074+ api , device , 15 , test_shape = test_shape , test_dtype = dtype
1075+ )
1076+
10631077 @parameterized .expand (COMMON_DEVICE_DTYPE )
10641078 @unittest .skipIf (
10651079 not TORCH_VERSION_AT_LEAST_2_5 , "gemlite tests needs torch 2.5 or greater"
@@ -1111,8 +1125,6 @@ def test_gemlite_layout(self, device, dtype):
11111125 # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
11121126 @skip_if_rocm ("ROCm enablement in progress" )
11131127 def test_int4_weight_only_quant_subclass_api_grouped (self , device , dtype ):
1114- if device == "cpu" :
1115- self .skipTest (f"Temporarily skipping for { device } " )
11161128 if dtype != torch .bfloat16 :
11171129 self .skipTest (f"Fails for { dtype } " )
11181130 layout_list = []
0 commit comments