 from torchao.dtypes.uint4 import unpack_uint4, pack_uint4


-def benchmark(function, num_runs, setup=None):
-    args = setup()
+def benchmark(function, args, num_runs):
     torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
@@ -21,207 +20,74 @@ def benchmark(function, num_runs, setup=None):


 def test_vs_existing():
-    def new_():
-        fake_tensor = torch.randint(0, 2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
+    def new_(scale):
+        fake_tensor = torch.randint(2**8, (1, scale, scale), dtype=torch.uint8).cuda()
         packed = pack(fake_tensor, 4, dim=1)
         unpacked = unpack(packed, 4, dim=1)
-    def old_():
-        fake_tensor = torch.randint(0, 2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
+    def old_(scale):
+        fake_tensor = torch.randint(2**8, (1, scale, scale), dtype=torch.uint8).cuda()
         packed = pack_uint4(fake_tensor)
         unpacked = unpack_uint4(packed)
-    new_ = torch.compile(new_, fullgraph=True)
-    old_ = torch.compile(old_, fullgraph=True)
-    new_()
-    old_()
-    print(f"new: {benchmark(new_, 1000)} ms")
-    print(f"old: {benchmark(old_, 1000)} ms")
-
+

-def test_iso_bitpack():
-    def load4x(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, 4 * scale, scale), dtype=torch.uint8).cuda()
+    for scale in [256, 512, 1024, 2048, 4096, 8192]:
+        new_ = torch.compile(new_, fullgraph=True)
+        old_ = torch.compile(old_, fullgraph=True)
+        new_(scale)
+        old_(scale)
+        print("scale: ", scale)
+        print(f"new: {benchmark(new_, [scale], 10)} ms")
+        print(f"old: {benchmark(old_, [scale], 10)} ms")
+
+
+def compare_to_fp16():
+    class Linear16(torch.nn.Module):
+        def __init__(self, scale):
+            super().__init__()
+            scale += scale % 2
+            self.l1 = torch.nn.Linear(scale * 2, scale, bias=False, dtype=torch.float16).cuda()
+            self.l2 = torch.nn.Linear(scale, scale // 2, bias=False, dtype=torch.float16).cuda()
+
+        def forward(self, x):
+            return self.l2(self.l1(x))

-    def load2x(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, 2 * scale, scale), dtype=torch.uint8).cuda()
-
-    def loadx(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
+    class W4A16_symmetric_weight_only(torch.nn.Module):
+        def __init__(self, scale):
+            super().__init__()
+            assert scale % 4 == 0
+            self.l1 = torch.randint(2**8, (scale, scale), dtype=torch.uint8).cuda()
+            self.s1 = torch.tensor((scale), dtype=torch.float16).cuda()
+            self.l2 = torch.randint(2**8, (scale // 2, scale // 4), dtype=torch.uint8).cuda()
+            self.s2 = torch.tensor((scale // 4), dtype=torch.float16).cuda()

-    def unpack8to2(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)

-    def unpack8to4(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
-
-    def t8to4wmm(scale=1024):
-        fake_tensor = torch.randint(0, 2**8, (8, 1024, 1024), dtype=torch.uint8).cuda()
-        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
+        def forward(self, x):
+            w = unpack(self.l1.detach(), 4, output_dtype=torch.float16)
+            x = x * self.s1
+            x = x @ w
+            w = unpack(self.l2.detach(), 4, output_dtype=torch.float16)
+            x = x * self.s2
+            x = x @ w

-    torch._dynamo.config.specialize_int = True
-    # _unpack_c = torch.compile(_unpack, fullgraph=True)
-    unpack_c = torch.compile(unpack, fullgraph=True)
-
-    scale = [16, 64, 256, 1024, 4096]
-    load4x_times = []
-    unpack8to2_times = []
-    load2x_times = []
-    unpack8to4_times = []
-    for s in scale:
-        res = benchmark(load4x, 50, scale=s)
-        load4x_times.append(res)
-        print(f"load(1, {4 * s},{s}) time: {res} ms")
-
-        res = benchmark(unpack8to2, 50, scale=s)
-        unpack8to2_times.append(res)
-        print(f"load(1, {s},{s}) unpack uint2 time: {res} ms")
+            return x
+
+    torch._dynamo.config.specialize_int = True
+    for scale in [256, 512, 1024, 2048, 4096, 8192]:
+        a = Linear16(scale)
+        b = W4A16_symmetric_weight_only(scale)
+        # a = torch.compile(a, fullgraph=True)
+        b = torch.compile(b, fullgraph=True)

-        res = benchmark(load2x, 50, scale=s)
-        load2x_times.append(res)
-        print(f"load(1, {2 * s},{s}) time: {res} ms")
-
-        res = benchmark(unpack8to4, 50, scale=s)
-        unpack8to4_times.append(res)
-        print(f"load(1, {s},{s}) unpack uint4 time: {res} ms")
-        print()
-
-    # import matplotlib.pyplot as plt
-    # plt.plot(scale, load4x_times, label="load(1, 4x, x)")
-    # plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2")
-    # plt.plot(scale, load2x_times, label="load(1, 2x, x)")
-    # plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4")
-    # plt.xlabel("scale")
-    # plt.ylabel("time (ms)")
-    # plt.yscale("log")
-    # plt.legend()
-    # plt.savefig("benchmark_bitpacking.png")
-
-
-def test_vs_hqqpack():
-    # requires hqq to be installed
-    import hqq
-    import hqq.core.quantize as hqq_quantize
-    HQQLinear = hqq_quantize.HQQLinear
-    BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
-    from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
-
-    BASE_QUANT_CONFIG = {
-        "optimize": True,
-        "view_as_float": False,
-        "nbits": 4,
-        "bitpack": False,
-        "axis": 1,
-    }
+        test_input = torch.randn(scale * 2, dtype=torch.float16).cuda()
+        forward_args = [test_input]
+        b.forward(test_input)
+        print("scale: ", scale)
+        print("fp16 time: ", benchmark(a.forward, forward_args, 100))
+        print("uint4 time: ", benchmark(b.forward, forward_args, 100))

-    def mixed_mm(
-        shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn=True
-    ):
-        qcfg = {
-            **BASE_QUANT_CONFIG,
-            **dict(group_size=group_size, axis=axis),
-        }
-        M, N, K = shape
-
-        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
-
-        quant_config = BaseQuantizeConfig(
-            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
-        )
-        quant_config.update({"weight_quant_params": qcfg})
-        hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
-        W_q, meta = hqq_linear.W_q, hqq_linear.meta
-        W_q = W_q.to(dtype=quant_dtype)
-        W_q = (
-            W_q.reshape(meta["shape"])
-            if quant_config["weight_quant_params"]["bitpack"] == False
-            else W_q
-        )
-        W_dq = hqq_linear.dequantize()
-
-        scales, zeros = meta["scale"], meta["zero"]
-        scales = scales.reshape(N, -1)
-        zeros = zeros.reshape(N, -1)
-        if pack_fn:
-            packed_w = pack(W_q.T, 4, dim=0, order=False)
-        else:
-            packed_w = pack_2xint4(W_q.T)
-
-        if transposed:
-            x = torch.randn(M, N, dtype=dtype, device="cuda")
-            hqq_out = x @ W_dq
-
-            tt_out = triton_mixed_mm(
-                x,
-                packed_w,
-                scales.T,
-                zeros.T,
-                transposed=True,
-                group_size=group_size,
-                fp8_fast_accum=False,
-                kernel_type=kernel_type,
-            )
-
-        else:
-            x = torch.randn(M, K, dtype=dtype, device="cuda")
-            hqq_out = x @ W_dq.T
-
-            tt_out = triton_mixed_mm(
-                x,
-                packed_w,
-                scales.T,
-                zeros.T,
-                transposed=False,
-                group_size=group_size,
-                fp8_fast_accum=False,
-                kernel_type=kernel_type,
-            )
-
-    shapes = [
-        [16, 128, 128],
-        [16, 4096, 4096],
-    ]
-    group_sizes = [64, 128]
-    shape = [16, 128, 128]
-    group_size = 64
-    pack = torch.compile(pack, fullgraph=True)
-    for i in range(2):
-        shape = shapes[i]
-        group_size = group_sizes[i]
-        print("linear layer size: ", shape)
-        print("group size: ", group_size)
-        # run once to compile
-        test_mixed_mm(
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound",
-            torch.uint8,
-        )
-        # shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
-        print("pack time (ms): ", benchmark(test_mixed_mm, 100,
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound",
-            torch.uint8))

-        print("pack_2xint4 time (ms): ", benchmark(test_mixed_mm, 100,
-            shape,
-            group_size,
-            1,
-            torch.float16,
-            True,
-            "compute_bound",  # max autotune doesnt work?
-            torch.uint8,
-            pack_fn=False))
-        print("")
-
-
+
 if __name__ == "__main__":
+    compare_to_fp16()
     test_vs_existing()
-
+
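For context, here is a minimal sketch of how the reworked benchmark(function, args, num_runs) helper could be completed and invoked. Only the new signature and the CUDA-event setup appear in the diff above; the timing loop, the per-run averaging, and the pack/unpack example below are assumptions, not code from this PR.

import torch
from torchao.dtypes.uint4 import pack_uint4, unpack_uint4

def benchmark(function, args, num_runs):
    # Assumed body: time num_runs calls with CUDA events and return the mean per run.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_runs):
        function(*args)
    end_event.record()
    torch.cuda.synchronize()
    # Event.elapsed_time reports milliseconds.
    return start_event.elapsed_time(end_event) / num_runs

if __name__ == "__main__":
    x = torch.randint(2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
    print("pack + unpack:", benchmark(lambda t: unpack_uint4(pack_uint4(t)), [x], 10), "ms")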