 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+    unwrap_tensor_subclass,  # needed by the _*_api helpers below
 )
 from torchao.quantization.quant_api import (
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 import copy
 
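+# Each _*_api helper dispatches to the version-appropriate quantization entry
+# point: torch >= 2.4 uses quantize_() (unwrapping tensor subclasses on
+# torch < 2.5, where torch.compile cannot yet trace them), while older torch
+# versions fall back to the legacy change_linear_weights_to_*qtensors APIs.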
+def _int8wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod, **kwargs)
+
+def _int8da_int8w_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod, **kwargs)
+
+def _int4wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
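+        # the legacy API spells this kwarg "groupsize"; int4_weight_only expects "group_size"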
+        kwargs_copy = kwargs.copy()
+        if "groupsize" in kwargs_copy:
+            kwargs_copy["group_size"] = kwargs_copy["groupsize"]
+            del kwargs_copy["groupsize"]
+        quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    """A single nn.Linear layer for an m x k x n matmul problem size."""
+    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.m = m
+        self.dtype = dtype
+        self.device = device
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)
 
-    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+    def example_inputs(self):
+        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)
 
     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        x = self.linear(x)
         return x
 
 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
@@ -69,14 +105,17 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
 _ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
 
 
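+# raise dynamo's recompile cache limit, presumably so the many torch.compile'd
+# (shape, API) benchmark variants below do not silently fall back to eager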
-def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
+torch._dynamo.config.cache_size_limit = 50000
+
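+# inference-only timing: run the whole benchmark without autograd bookkeeping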
+@torch.no_grad
+def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}
 
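+    # three copies of the model: m is quantized via `api`, m_ref via the
+    # reference `ref_api`, and m_bf16 stays un-quantized as the bf16 baseline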
76- m = ToyLinearModel (1024 , 1024 , 1024 ).eval ().to (torch .bfloat16 ).to ("cuda" )
115+ m = ToyLinearModel (M , N , K , has_bias = True , dtype = torch .bfloat16 , device = "cuda" ).eval ()
116+ m_bf16 = copy .deepcopy (m )
77117 m_ref = copy .deepcopy (m )
78- # setting batch_size to 20 to be compatible with the kernel
79- example_inputs = m .example_inputs (batch_size = 20 , dtype = torch .bfloat16 , device = "cuda" )
118+ example_inputs = m .example_inputs ()
80119
81120 api (m , ** kwargs )
82121
@@ -91,27 +130,41 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
     # perf comparison
     from torchao.utils import benchmark_model
     # warmup
-    WARMUP = 5
+    WARMUP = 20
     RUNS = 100
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
-
-    benchmark_model(m, WARMUP, example_inputs)
-    elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
     m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
-    print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
-    assert elapsed_time < 1.05 * ref_elapsed_time
+    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    benchmark_model(m, WARMUP, example_inputs)
+    elapsed_time = benchmark_model(m, RUNS, example_inputs)
+
+    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
+
+    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
 
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
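+    # each entry is (M, N, K) = (tokens, out_features, in_features); append
+    # more tuples to benchmark additional problem sizes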
+    all_shapes = [
+        (20, 2048, 2048),
+    ]
+
+    print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)
 
+    print("_int8wo_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)
 
+    print("_int4wo_api")
     kwargs = {"groupsize": 32}
     from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)