Commit d83f3f7

Michael Goin authored
Fixes and updates to bench_per_token_quant_fp8 (#25591)
Signed-off-by: Michael Goin <[email protected]>
1 parent 302eb94 commit d83f3f7

File tree

1 file changed: +26 -24 lines changed


benchmarks/kernels/bench_per_token_quant_fp8.py

Lines changed: 26 additions & 24 deletions
@@ -51,31 +51,33 @@ def calculate_diff(
 ):
     """Calculate the difference between Inductor and CUDA implementations."""
     device = torch.device("cuda")
-    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
+    x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)
 
     quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
 
     torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
     torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
     cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
 
-    out_allclose = lambda o1, o2: torch.allclose(
-        o1.to(torch.float32),
-        o2.to(torch.float32),
-        rtol=1e-3,
-        atol=1e-5,
-    )
-    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
-
-    if (
-        out_allclose(cuda_out, torch_out)
-        and scale_allclose(cuda_scale, torch_scale)
-        and out_allclose(cuda_out, torch_eager_out)
-        and scale_allclose(cuda_scale, torch_eager_scale)
-    ):
+    try:
+        torch.testing.assert_close(
+            cuda_out.to(torch.float32),
+            torch_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
+        torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
+        torch.testing.assert_close(
+            cuda_out.to(torch.float32),
+            torch_eager_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
+        torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
         print("✅ All implementations match")
-    else:
+    except AssertionError as e:
         print("❌ Implementations differ")
+        print(e)
 
 
 configs = []
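
Replacing the torch.allclose lambdas with torch.testing.assert_close changes how failures surface: instead of a bare ❌, a mismatch now raises an AssertionError whose message reports how many elements differ and the greatest absolute and relative error. A minimal standalone sketch of the pattern (the tensors here are made up for illustration; the benchmark compares real kernel outputs):

import torch

# Illustrative inputs that differ slightly in the last element.
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])

try:
    # Unlike torch.allclose, assert_close raises an AssertionError whose
    # message pinpoints the mismatched element count and the greatest
    # absolute/relative differences, which then gets printed below.
    torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-5)
    print("✅ match")
except AssertionError as e:
    print("❌ mismatch")
    print(e)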
@@ -91,7 +93,7 @@ def benchmark_quantization(
 ):
     device = torch.device("cuda")
 
-    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)
 
     quantiles = [0.5, 0.2, 0.8]
     quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
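
Both hunks above fix the benchmark's input shape: the old (batch_size * hidden_size, 4096) tensor hard-coded a 4096-wide row, so the configured hidden size never reached the dimension being quantized; with (batch_size, hidden_size), each row is one token of the requested width. For intuition, here is a reference sketch of what per-token FP8 quantization computes; it is a stand-in for exposition only, not vLLM's QuantFP8 kernel:

import torch

def per_token_quant_fp8_reference(x: torch.Tensor):
    # Sketch only, not vLLM's QuantFP8: each row (token) gets its own scale
    # taken over the hidden dimension, which is why that dimension must be
    # the tensor's last axis.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = scale.clamp(min=1e-12)  # guard all-zero rows against div-by-zero
    x_q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scale

x = torch.randn(16, 4096, dtype=torch.bfloat16)
x_q, scale = per_token_quant_fp8_reference(x)
assert x_q.shape == x.shape and scale.shape == (16, 1)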
@@ -157,21 +159,21 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series:
     )
     parser.add_argument("-c", "--check", action="store_true")
     parser.add_argument(
-        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
     )
     parser.add_argument(
         "--hidden-sizes",
         type=int,
         nargs="+",
-        default=None,
-        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
+        default=[896, 1024, 2048, 4096, 7168],
+        help="Hidden sizes to benchmark",
     )
     parser.add_argument(
         "--batch-sizes",
         type=int,
         nargs="+",
-        default=None,
-        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
+        default=[1, 16, 128, 512, 1024],
+        help="Batch sizes to benchmark",
     )
     parser.add_argument(
         "--group-sizes",
@@ -192,8 +194,8 @@ def geo_speedup(group: pd.DataFrame) -> pd.Series:
 
     dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
 
-    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
-    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
+    hidden_sizes = args.hidden_sizes
+    batch_sizes = args.batch_sizes
 
     if args.group_sizes is not None:
         group_shapes = []
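
For reference, the STR_DTYPE_TO_TORCH_DTYPE lookup used above maps the flag's string choices to torch dtypes. A sketch of the subset relevant to this script's three --dtype choices; this is an assumption about the shape of the table, whose real version is imported from vLLM and covers more dtypes:

import torch

# Assumed subset of vLLM's STR_DTYPE_TO_TORCH_DTYPE; sketch only, the real
# mapping is imported from vLLM rather than defined in the benchmark.
STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
}

assert STR_DTYPE_TO_TORCH_DTYPE["bfloat16"] is torch.bfloat16  # the new default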
