@@ -51,31 +51,33 @@ def calculate_diff(
5151):
5252    """Calculate the difference between Inductor and CUDA implementations.""" 
5353    device  =  torch .device ("cuda" )
54-     x  =  torch .rand ((batch_size   *   hidden_size ,  4096 ), dtype = dtype , device = device )
54+     x  =  torch .randn ((batch_size ,  hidden_size ), dtype = dtype , device = device )
5555
5656    quant_fp8  =  QuantFP8 (False , group_shape , column_major_scales = False )
5757
5858    torch_out , torch_scale  =  bench_compile (quant_fp8 .forward_native )(x )
5959    torch_eager_out , torch_eager_scale  =  quant_fp8 .forward_native (x )
6060    cuda_out , cuda_scale  =  quant_fp8 .forward_cuda (x )
6161
62-     out_allclose  =  lambda  o1 , o2 : torch .allclose (
63-         o1 .to (torch .float32 ),
64-         o2 .to (torch .float32 ),
65-         rtol = 1e-3 ,
66-         atol = 1e-5 ,
67-     )
68-     scale_allclose  =  lambda  s1 , s2 : torch .allclose (s1 , s2 , rtol = 1e-3 , atol = 1e-5 )
69- 
70-     if  (
71-         out_allclose (cuda_out , torch_out )
72-         and  scale_allclose (cuda_scale , torch_scale )
73-         and  out_allclose (cuda_out , torch_eager_out )
74-         and  scale_allclose (cuda_scale , torch_eager_scale )
75-     ):
62+     try :
63+         torch .testing .assert_close (
64+             cuda_out .to (torch .float32 ),
65+             torch_out .to (torch .float32 ),
66+             rtol = 1e-3 ,
67+             atol = 1e-5 ,
68+         )
69+         torch .testing .assert_close (cuda_scale , torch_scale , rtol = 1e-3 , atol = 1e-5 )
70+         torch .testing .assert_close (
71+             cuda_out .to (torch .float32 ),
72+             torch_eager_out .to (torch .float32 ),
73+             rtol = 1e-3 ,
74+             atol = 1e-5 ,
75+         )
76+         torch .testing .assert_close (cuda_scale , torch_eager_scale , rtol = 1e-3 , atol = 1e-5 )
7677        print ("✅ All implementations match" )
77-     else :
78+     except   AssertionError   as   e :
7879        print ("❌ Implementations differ" )
80+         print (e )
7981
8082
8183configs  =  []
@@ -91,7 +93,7 @@ def benchmark_quantization(
9193):
9294    device  =  torch .device ("cuda" )
9395
94-     x  =  torch .randn (batch_size   *   hidden_size ,  4096 , device = device , dtype = dtype )
96+     x  =  torch .randn (batch_size ,  hidden_size , device = device , dtype = dtype )
9597
9698    quantiles  =  [0.5 , 0.2 , 0.8 ]
9799    quant_fp8  =  QuantFP8 (False , group_shape , column_major_scales = col_major )
@@ -157,21 +159,21 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series:
157159    )
158160    parser .add_argument ("-c" , "--check" , action = "store_true" )
159161    parser .add_argument (
160-         "--dtype" , type = str , choices = ["half" , "bfloat16" , "float" ], default = "half " 
162+         "--dtype" , type = str , choices = ["half" , "bfloat16" , "float" ], default = "bfloat16 " 
161163    )
162164    parser .add_argument (
163165        "--hidden-sizes" ,
164166        type = int ,
165167        nargs = "+" ,
166-         default = None ,
167-         help = "Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096) " ,
168+         default = [ 896 ,  1024 ,  2048 ,  4096 ,  7168 ] ,
169+         help = "Hidden sizes to benchmark" ,
168170    )
169171    parser .add_argument (
170172        "--batch-sizes" ,
171173        type = int ,
172174        nargs = "+" ,
173-         default = None ,
174-         help = "Batch sizes to benchmark (default: 1,16,32,64,128) " ,
175+         default = [ 1 ,  16 ,  128 ,  512 ,  1024 ] ,
176+         help = "Batch sizes to benchmark" ,
175177    )
176178    parser .add_argument (
177179        "--group-sizes" ,
@@ -192,8 +194,8 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series:
192194
193195    dtype  =  STR_DTYPE_TO_TORCH_DTYPE [args .dtype ]
194196
195-     hidden_sizes  =  args .hidden_sizes   or  [ 1 ,  16 ,  64 ,  128 ,  256 ,  512 ,  1024 ,  2048 ,  4096 ] 
196-     batch_sizes  =  args .batch_sizes   or  [ 1 ,  16 ,  32 ,  64 ,  128 ] 
197+     hidden_sizes  =  args .hidden_sizes 
198+     batch_sizes  =  args .batch_sizes 
197199
198200    if  args .group_sizes  is  not None :
199201        group_shapes  =  []
0 commit comments