@@ -6,13 +6,14 @@
 
 """
 This is a script to estimate the benefit from converting a `torch.nn.Linear`
-layer to float8, by estimating the difference in e2e GPU kernel time between:
+layer to float8 given a single saturated GPU, by estimating the difference
+in e2e GPU kernel time between:
 1. bf16 gemms in fwd and bwd, and
 2. float8 gemms in fwd and bwd, and float8 overhead
 
 The gemm times are estimated either from direct measurements via benchmarks,
 or with a roofline estimation based on TOPS and peak memory bandwidth of an
-NVIDIA H100.
+NVIDIA H100 or B200.
 
 The float8 overhead times are estimated by counting memory reads and writes
 based on the specified float8 scaling, and estimating that we can achieve
@@ -31,12 +32,10 @@
   input_t @ grad_output = grad_weight
   KxM @ MxN => KxN
 
-2. we properly model the worst-case of the current torch.compile limitations regarding
-   float8 scaling
-3. assume for float8 activations/gradients that torch.compile will fuse to the
+2. assume for float8 activations/gradients that torch.compile will fuse to the
 preceding op. Note that this is not always true in practice.
-4. assume no AC (TODO model it)
-5. assume no float8 all-gather (TODO model it)
+3. assume no AC (TODO model it)
+4. assume no float8 all-gather (TODO model it)
 """
 
 import copy
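
For intuition, the roofline gemm estimate referenced in the docstring reduces to taking the max of the compute-bound and memory-bound times. Below is a minimal sketch, not part of this PR; the helper name is hypothetical (not the repo's `get_gemm_time_sympy`) and the specs are approximate H100 SXM numbers:

# illustrative roofline sketch; specs are approximate H100 SXM values
BF16_TOPS = 989e12      # ~989 TFLOPS dense bf16
FP8_TOPS = 1979e12      # ~1979 TFLOPS dense fp8
PEAK_MEM_BW = 3.35e12   # ~3.35 TB/s HBM3

def roofline_gemm_time_s(M, K, N, bytes_per_elem, tops):
    # a gemm does 2*M*K*N flops; it reads A (M*K) and B (K*N) and writes C (M*N)
    compute_s = 2 * M * K * N / tops
    mem_s = (M * K + K * N + M * N) * bytes_per_elem / PEAK_MEM_BW
    return max(compute_s, mem_s)

# example: the fwd gemm of a 4096x4096 linear layer at M=16384
bf16_s = roofline_gemm_time_s(16384, 4096, 4096, 2, BF16_TOPS)  # ~5.6e-4 s, compute bound
fp8_s = roofline_gemm_time_s(16384, 4096, 4096, 1, FP8_TOPS)    # ~2.8e-4 s
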
@@ -164,68 +163,60 @@ def do_matmul(A, B):
 
 def run(
     outfile: str,
-    gemm_time_strategy: str = "benchmarks",
-    model_torch_compile_limitations: bool = False,
+    do_benchmarks: bool = True,
     shape_gen_name: str = "square",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
 ):
     """
     Args:
-    * `gemm_time_strategy`:
-      - `benchmarks`: use benchmarks for gemm times (more accurate for all shapes)
-      - `roofline`: use roofline model for gemm times (only accurate for large shapes)
+    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `square`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     """
 
-    print(f"gemm_time_strategy: {gemm_time_strategy}")
+    print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
 
-    assert gemm_time_strategy in (
-        "benchmarks",
-        "roofline",
-    ), "`gemm_time_strategy` must be 'benchmarks' or 'roofline'"
-
     M, K, N = sympy.symbols("M K N")
 
-    fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=True,
-    )
     fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
         M,
         K,
         N,
-        model_torch_compile_limitations=False,
     )
 
-    if gemm_time_strategy == "roofline":
-        bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
-        print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-        fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
-        print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
-        print()
-    else:
-        print()
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
+    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
+    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print()
 
     headers = [
         "fwd_M",
         "fwd_K",
         "fwd_N",
-        # gemm microbenchmarks
-        "bf16_gemm_s",
-        "fp8_gemm_s",
-        # roofline memory overhead estimates
-        "fp8_oh_estimated",
-        "fp8_oh_ideal",
-        # actual e2e measurements
-        "bf16_s",
-        "fp8_dyn_s",
-        "fp8_dyn_sp",
+        # roofline - gemm time (fwd + bwd, 3 gemms)
+        "r_bf16_gemm_s",
+        "r_fp8_gemm_s",
+        # roofline - fp8 overhead time (by counting reads/writes in the ideal case)
+        "r_fp8_ovhd_s",
+        # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid)
+        "r_fp8_gemm_and_ovhd_s",
+        "r_fp8_gemm_and_ovhd_spdp",
+        # benchmarks - gemm time (fwd + bwd, 3 gemms)
+        "b_bf16_gemm_s",
+        "b_fp8_gemm_s",
+        # benchmarks - e2e LNLinearSigmoid time fwd + bwd
+        "b_bf16_e2e_s",
+        "b_fp8_e2e_s",
+        # note that e2e speedup is not the same as the roofline speedup:
+        # 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time)
+        # 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid)
+        # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
+        # we don't break them out and don't have a roofline for them.
+        "b_fp8_e2e_spdp",
     ]
     results = []
 
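
The comment in the new headers above notes that the e2e speedup is diluted relative to the roofline speedup by the LN and sigmoid terms, which are identical in both variants. A worked example with made-up numbers (illustrative only, not measured values) makes this concrete:

# hypothetical per-iteration times, in seconds
bf16_gemm_s, fp8_gemm_s, fp8_ovhd_s = 1.00e-3, 0.50e-3, 0.20e-3
ln_and_sigmoid_s = 0.30e-3  # identical in the bf16 and fp8 variants

roofline_speedup = bf16_gemm_s / (fp8_gemm_s + fp8_ovhd_s)
# = 1.0e-3 / 0.7e-3, about 1.43x
e2e_speedup = (ln_and_sigmoid_s + bf16_gemm_s) / (ln_and_sigmoid_s + fp8_gemm_s + fp8_ovhd_s)
# = 1.3e-3 / 1.0e-3 = 1.30x, smaller than the roofline speedup
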
@@ -235,7 +226,18 @@ def run(
         if n_limit is not None and idx >= n_limit:
             break
 
-        if gemm_time_strategy == "benchmarks":
+        # use roofline model to estimate gemm time
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_bf16_gemm_time_s = float(
+            bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+        r_fp8_gemm_time_s = float(
+            fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+
+        # if enabled, also measure the observed gemm time
+        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        if do_benchmarks:
             bf16_g1, f8_g1 = get_gemm_times(
                 M_val, K_val, N_val, True, gemm_cache_filename
             )
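
As background for the `r_fp8_ovhd_s` column filled in below, the float8 overhead roofline follows the "count memory reads and writes" idea from the module docstring. A simplified sketch for dynamically scaling one bf16 tensor, using a hypothetical helper (not the repo's `get_float8_mem_sympy`, which also models both gemm operands in fwd and bwd and fusion with the preceding op):

# count bytes moved when dynamically casting a bf16 tensor of `numel` elements
# to float8, then divide by an assumed achievable memory bandwidth
def fp8_dynamic_cast_ovhd_s(numel, peak_mem_bw=3.35e12):
    bytes_moved = (
        2 * numel    # read the bf16 tensor to compute amax
        + 2 * numel  # read the bf16 tensor again for the cast (no fusion assumed here)
        + 1 * numel  # write the float8 tensor
    )
    return bytes_moved / peak_mem_bw
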
@@ -245,60 +247,58 @@ def run(
             bf16_g3, f8_g3 = get_gemm_times(
                 K_val, M_val, N_val, False, gemm_cache_filename
             )
-            bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
-            fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-        else:
-            assert gemm_time_strategy == "roofline", "unsupported"
-            bf16_time_val = (
-                bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
-            fp8_gemm_time_s = (
-                fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
+            b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
+            b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
 
-        fp8_mem_time_dyn_limit_s = (
-            fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        fp8_mem_time_dyn_nolimit_s = (
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_fp8_ovhd_time_s = float(
             fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
 
-        # create the model
-        m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
-        x = torch.randn(
-            M_val, K_val, dtype=torch.bfloat16, device="cuda"
-        ).requires_grad_()
+        b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
+        if do_benchmarks:
+            # create the model
+            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            x = torch.randn(
+                M_val, K_val, dtype=torch.bfloat16, device="cuda"
+            ).requires_grad_()
 
-        # get the bf16 gpu kernel time
-        torch._dynamo.reset()
-        m_bf16 = torch.compile(copy.deepcopy(m_orig))
-        bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)
+            # get the bf16 gpu kernel time
+            torch._dynamo.reset()
+            m_bf16 = torch.compile(copy.deepcopy(m_orig))
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
 
-        # get the float8 dynamic scaling gpu kernel time
+            # get the float8 dynamic scaling gpu kernel time
 
-        torch._dynamo.reset()
-        m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
-        m_fp8_dyn = torch.compile(m_fp8_dyn)
-        fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            torch._dynamo.reset()
+            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            m_fp8_dyn = torch.compile(m_fp8_dyn)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
         results.append(
             [
                 M_val,
                 K_val,
                 N_val,
-                # gemm microbenchmarks
-                bf16_time_val,
-                fp8_gemm_time_s,
-                # roofline overhead estimates
-                fp8_mem_time_dyn_limit_s,
-                fp8_mem_time_dyn_nolimit_s,
-                # e2e numbers
-                bf16_time_actual_s,
-                fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_time_actual_s,
+                # roofline - gemm
+                r_bf16_gemm_time_s,
+                r_fp8_gemm_time_s,
+                # roofline - fp8 overhead
+                r_fp8_ovhd_time_s,
+                # roofline - gemm + overhead, and speedup
+                r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
+                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                # benchmarks - gemm
+                b_bf16_gemm_time_s,
+                b_fp8_gemm_time_s,
+                # benchmarks - e2e, and speedup
+                b_bf16_e2e_time_s,
+                b_fp8_e2e_time_s,
+                b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
             ]
         )
 
+    pd.set_option("display.precision", 2)
     df = pd.DataFrame(results, columns=headers)
     print(df)
     df.to_csv(outfile)
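
For readers curious how an e2e GPU kernel time like the one collected above can be measured, here is a minimal sketch with torch.profiler. It is an assumption for illustration only, not necessarily how `get_gpu_kernel_time` is implemented in this repo:

import torch
from torch.profiler import ProfilerActivity, profile

def measure_gpu_kernel_time_s(m, x, n_iter=10):
    # warmup so compilation and autotuning are excluded from the measurement
    for _ in range(3):
        m(x).sum().backward()
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        for _ in range(n_iter):
            m(x).sum().backward()
        torch.cuda.synchronize()
    # sum of per-kernel CUDA self time, reported in microseconds
    total_us = sum(e.self_cuda_time_total for e in prof.key_averages())
    return total_us / 1e6 / n_iter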