From 8aa430d655f76d6a62f27f34207c5ef2d948978e Mon Sep 17 00:00:00 2001
From: mgoin
Date: Wed, 26 Feb 2025 19:18:23 +0000
Subject: [PATCH 1/7] Add benchmark for DeepGEMM and vLLM Block FP8 Dense GEMM

Signed-off-by: mgoin
---
 benchmarks/kernels/deepgemm/README.md         | 170 +++++++++++
 .../benchmark_fp8_block_dense_gemm.py         | 267 ++++++++++++++++++
 2 files changed, 437 insertions(+)
 create mode 100644 benchmarks/kernels/deepgemm/README.md
 create mode 100644 benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md
new file mode 100644
index 000000000000..02c5acd0be19
--- /dev/null
+++ b/benchmarks/kernels/deepgemm/README.md
@@ -0,0 +1,170 @@
+# DeepSeek DeepGEMM Kernels Benchmark
+
+This directory contains benchmarks comparing DeepSeek's DeepGEMM block FP8 kernels against vLLM's existing Triton- and CUTLASS-based kernels.
+
+Currently this covers only dense GEMMs and only runs on Hopper GPUs.
+
+## Setup
+
+Install vLLM in your usual fashion, then install DeepGEMM from source:
+
+```
+git clone --recursive https://github.com/deepseek-ai/DeepGEMM
+uv pip install -e DeepGEMM
+```
+
+## Usage
+
+```
+python benchmark_fp8_block_dense_gemm.py
+INFO 02-26 19:12:16 [__init__.py:207] Automatically detected platform cuda.
+===== STARTING FP8 GEMM BENCHMARK =====
+Using device: NVIDIA H100 80GB HBM3
+
+=== Benchmarking shape: m=8, n=4096, k=7168 ===
+Running correctness check...
+WARNING 02-26 19:12:19 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
+DeepGEMM vs Reference difference: 0.000689
+vLLM Triton vs Reference difference: 0.000691
+vLLM CUTLASS vs Reference difference: 0.000691
+vLLM Triton vs DeepGEMM difference: 0.000011
+vLLM CUTLASS vs DeepGEMM difference: 0.000011
+DeepGEMM: 0.129 ms, 3.64 TFLOPS
+vLLM Triton: 0.074 ms, 6.35 TFLOPS
+vLLM CUTLASS: 0.034 ms, 13.71 TFLOPS
+DeepGEMM is 1.74x faster than vLLM Triton
+DeepGEMM is 3.76x faster than vLLM CUTLASS
+vLLM CUTLASS is 2.16x faster than vLLM Triton
+
+=== Benchmarking shape: m=8, n=7168, k=18432 ===
+Running correctness check...
+INFO 02-26 19:12:19 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
+DeepGEMM vs Reference difference: 0.000680
+vLLM Triton vs Reference difference: 0.000680
+vLLM CUTLASS vs Reference difference: 0.000680
+vLLM Triton vs DeepGEMM difference: 0.000010
+vLLM CUTLASS vs DeepGEMM difference: 0.000010
+DeepGEMM: 0.114 ms, 18.48 TFLOPS
+vLLM Triton: 0.091 ms, 23.14 TFLOPS
+vLLM CUTLASS: 0.082 ms, 25.86 TFLOPS
+DeepGEMM is 1.25x faster than vLLM Triton
+DeepGEMM is 1.40x faster than vLLM CUTLASS
+vLLM CUTLASS is 1.12x faster than vLLM Triton
+
+=== Benchmarking shape: m=8, n=18432, k=7168 ===
+Running correctness check...
+WARNING 02-26 19:12:19 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! 
Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +DeepGEMM vs Reference difference: 0.000682 +vLLM Triton vs Reference difference: 0.000682 +vLLM CUTLASS vs Reference difference: 0.000682 +vLLM Triton vs DeepGEMM difference: 0.000005 +vLLM CUTLASS vs DeepGEMM difference: 0.000005 +DeepGEMM: 0.113 ms, 18.68 TFLOPS +vLLM Triton: 0.117 ms, 18.03 TFLOPS +vLLM CUTLASS: 0.082 ms, 25.76 TFLOPS +DeepGEMM is 0.97x slower than vLLM Triton +DeepGEMM is 1.38x faster than vLLM CUTLASS +vLLM CUTLASS is 1.43x faster than vLLM Triton + +=== Benchmarking shape: m=128, n=4096, k=7168 === +Running correctness check... +DeepGEMM vs Reference difference: 0.000682 +vLLM Triton vs Reference difference: 0.000682 +vLLM CUTLASS vs Reference difference: 0.000682 +vLLM Triton vs DeepGEMM difference: 0.000007 +vLLM CUTLASS vs DeepGEMM difference: 0.000007 +DeepGEMM: 0.114 ms, 65.79 TFLOPS +vLLM Triton: 0.091 ms, 82.65 TFLOPS +vLLM CUTLASS: 0.039 ms, 191.25 TFLOPS +DeepGEMM is 1.26x faster than vLLM Triton +DeepGEMM is 2.91x faster than vLLM CUTLASS +vLLM CUTLASS is 2.31x faster than vLLM Triton + +=== Benchmarking shape: m=128, n=7168, k=18432 === +Running correctness check... +DeepGEMM vs Reference difference: 0.000683 +vLLM Triton vs Reference difference: 0.000683 +vLLM CUTLASS vs Reference difference: 0.000683 +vLLM Triton vs DeepGEMM difference: 0.000008 +vLLM CUTLASS vs DeepGEMM difference: 0.000008 +DeepGEMM: 0.115 ms, 293.95 TFLOPS +vLLM Triton: 0.143 ms, 236.69 TFLOPS +vLLM CUTLASS: 0.093 ms, 363.23 TFLOPS +DeepGEMM is 0.81x slower than vLLM Triton +DeepGEMM is 1.24x faster than vLLM CUTLASS +vLLM CUTLASS is 1.53x faster than vLLM Triton + +=== Benchmarking shape: m=128, n=18432, k=7168 === +Running correctness check... +DeepGEMM vs Reference difference: 0.000684 +vLLM Triton vs Reference difference: 0.000684 +vLLM CUTLASS vs Reference difference: 0.000684 +vLLM Triton vs DeepGEMM difference: 0.000007 +vLLM CUTLASS vs DeepGEMM difference: 0.000007 +DeepGEMM: 0.112 ms, 301.67 TFLOPS +vLLM Triton: 0.228 ms, 148.41 TFLOPS +vLLM CUTLASS: 0.086 ms, 395.53 TFLOPS +DeepGEMM is 0.49x slower than vLLM Triton +DeepGEMM is 1.31x faster than vLLM CUTLASS +vLLM CUTLASS is 2.67x faster than vLLM Triton + +=== Benchmarking shape: m=1024, n=4096, k=7168 === +Running correctness check... +DeepGEMM vs Reference difference: 0.000683 +vLLM Triton vs Reference difference: 0.000683 +vLLM CUTLASS vs Reference difference: 0.000683 +vLLM Triton vs DeepGEMM difference: 0.000007 +vLLM CUTLASS vs DeepGEMM difference: 0.000007 +DeepGEMM: 0.171 ms, 351.94 TFLOPS +vLLM Triton: 0.241 ms, 249.66 TFLOPS +vLLM CUTLASS: 0.101 ms, 598.08 TFLOPS +DeepGEMM is 0.71x slower than vLLM Triton +DeepGEMM is 1.70x faster than vLLM CUTLASS +vLLM CUTLASS is 2.40x faster than vLLM Triton + +=== Benchmarking shape: m=1024, n=18432, k=7168 === +Running correctness check... 
+DeepGEMM vs Reference difference: 0.000684
+vLLM Triton vs Reference difference: 0.000684
+vLLM CUTLASS vs Reference difference: 0.000684
+vLLM Triton vs DeepGEMM difference: 0.000007
+vLLM CUTLASS vs DeepGEMM difference: 0.000007
+DeepGEMM: 0.347 ms, 780.08 TFLOPS
+vLLM Triton: 0.898 ms, 301.38 TFLOPS
+vLLM CUTLASS: 0.331 ms, 817.56 TFLOPS
+DeepGEMM is 0.39x slower than vLLM Triton
+DeepGEMM is 1.05x faster than vLLM CUTLASS
+vLLM CUTLASS is 2.71x faster than vLLM Triton
+
+=== Benchmarking shape: m=2048, n=4096, k=7168 ===
+Running correctness check...
+DeepGEMM vs Reference difference: 0.000683
+vLLM Triton vs Reference difference: 0.000683
+vLLM CUTLASS vs Reference difference: 0.000683
+vLLM Triton vs DeepGEMM difference: 0.000007
+vLLM CUTLASS vs DeepGEMM difference: 0.000007
+DeepGEMM: 0.321 ms, 374.33 TFLOPS
+vLLM Triton: 0.461 ms, 261.05 TFLOPS
+vLLM CUTLASS: 0.200 ms, 601.60 TFLOPS
+DeepGEMM is 0.70x slower than vLLM Triton
+DeepGEMM is 1.61x faster than vLLM CUTLASS
+vLLM CUTLASS is 2.30x faster than vLLM Triton
+
+===== BENCHMARK SUMMARY =====
+Matrix multiplication: C[m,n] = A[m,k] @ B[n,k].T
+
+Average speedups:
+DeepGEMM vs vLLM Triton: 1.32x faster
+DeepGEMM vs vLLM CUTLASS: 0.64x slower
+vLLM CUTLASS vs vLLM Triton: 2.07x faster
+
+Average TFLOPS:
+DeepGEMM: 245.40 TFLOPS
+vLLM Triton: 147.48 TFLOPS
+vLLM CUTLASS: 336.95 TFLOPS
+
+Average accuracy difference vs reference:
+DeepGEMM: 0.000683
+vLLM Triton: 0.000684
+vLLM CUTLASS: 0.000684
+```
diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
new file mode 100644
index 000000000000..7c74be91d829
--- /dev/null
+++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
@@ -0,0 +1,267 @@
+# SPDX-License-Identifier: Apache-2.0
+import time
+from typing import Dict, Tuple
+
+# Import DeepGEMM functions
+import deep_gemm
+import torch
+from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
+
+# Import vLLM functions
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+
+
+# Copied from
+# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
+def per_token_cast_to_fp8(
+        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Convert tensor to FP8 format with per-token scaling."""
+    assert x.dim() == 2 and x.size(1) % 128 == 0
+    m, n = x.shape
+    x_view = x.view(m, -1, 128)
+    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
+    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
+        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+
+
+# Copied from
+# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
+def per_block_cast_to_fp8(
+        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Convert tensor to FP8 format with per-block scaling."""
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
+                           dtype=x.dtype,
+                           device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
+        x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+
+
+def 
benchmark_shape(m: int, + n: int, + k: int, + warmup: int = 10, + repeat: int = 1000) -> Dict[str, Dict[str, float]]: + """Benchmark all implementations for a specific (m, n, k) shape.""" + print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") + + # Create test tensors + A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + + # Reference result in BF16 + torch.cuda.synchronize() + C_ref = A @ B.t() + + # Pre-quantize B for all implementations + # (weights can be pre-quantized offline) + B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B) + B_vllm, B_scale_vllm = per_block_cast_to_fp8(B) + + # Block size configuration + block_size = [128, 128] + + results = {} + + # === DeepGEMM Implementation === + def deepgemm_gemm(): + # A quantization is inside the loop as it depends on activations + A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) + A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) + C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_aligned), + (B_deepgemm, B_scale_deepgemm), + C_deepgemm) + return C_deepgemm + + # === vLLM Triton Implementation === + def vllm_triton_gemm(): + # A quantization is inside the loop as it depends on activations + A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + return w8a8_block_fp8_matmul(A_vllm, + B_vllm, + A_scale_vllm, + B_scale_vllm, + block_size, + output_dtype=torch.bfloat16) + + # === vLLM CUTLASS Implementation === + def vllm_cutlass_gemm(): + # A quantization is inside the loop as it depends on activations + A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + A, block_size[1], column_major_scales=True) + return ops.cutlass_scaled_mm(A_vllm_cutlass, + B_vllm.T, + scale_a=A_scale_vllm_cutlass, + scale_b=B_scale_vllm.T, + out_dtype=torch.bfloat16) + + # Run correctness check first + print("Running correctness check...") + C_deepgemm = deepgemm_gemm() + C_vllm_triton = vllm_triton_gemm() + C_vllm_cutlass = vllm_cutlass_gemm() + + deepgemm_diff = calc_diff(C_deepgemm, C_ref) + vllm_triton_diff = calc_diff(C_vllm_triton, C_ref) + vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref) + + print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") + print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") + print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") + print("vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") + print("vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + + # Benchmark implementations + implementations = { + "DeepGEMM": deepgemm_gemm, + "vLLM Triton": vllm_triton_gemm, + "vLLM CUTLASS": vllm_cutlass_gemm + } + + for name, func in implementations.items(): + # Warmup + for _ in range(warmup): + func() + torch.cuda.synchronize() + + # Timing loop + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + func() + torch.cuda.synchronize() + end = time.time() + + # Calculate timing and TFLOPS + avg_time_ms = (end - start) / repeat * 1000 + flops = 2 * m * n * k # multiply-adds + tflops = flops / (avg_time_ms * 1e-3) / 1e12 + + results[name] = { + "time_ms": avg_time_ms, + "tflops": tflops, + "diff": { + "DeepGEMM": + deepgemm_diff if name == "DeepGEMM" else calc_diff( + func(), C_deepgemm), + "Reference": + deepgemm_diff if name == "DeepGEMM" else + (vllm_triton_diff + if name == "vLLM Triton" 
else vllm_cutlass_diff)
+            }
+        }
+
+        print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS")
+
+    # Calculate speedups
+    baseline = results["DeepGEMM"]["time_ms"]
+    for name, data in results.items():
+        if name != "DeepGEMM":
+            speedup = baseline / data["time_ms"]
+            print(f"DeepGEMM is {speedup:.2f}x "
+                  f"{'faster' if speedup > 1 else 'slower'} than {name}")
+
+    vllm_triton_time = results["vLLM Triton"]["time_ms"]
+    vllm_cutlass_time = results["vLLM CUTLASS"]["time_ms"]
+    cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
+    print(
+        f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
+        f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton")
+
+    return results
+
+
+def run_benchmarks():
+    """Run benchmarks for a set of common shapes."""
+    print("===== STARTING FP8 GEMM BENCHMARK =====")
+
+    # Make sure we're using the GPU
+    if not torch.cuda.is_available():
+        print("CUDA not available! Tests require GPU.")
+        return
+
+    print(f"Using device: {torch.cuda.get_device_name()}")
+
+    # Enable TF32 for better performance
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    # Set seeds for reproducibility
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+
+    # Define benchmark shapes (m, n, k)
+    # Common matrix shapes from LLM inference
+    shapes = [
+        # Batch sizes x hidden dim, output dim, hidden dim
+        (8, 4096, 7168),  # Small batch
+        (8, 7168, 18432),  # Small batch MLP up-proj
+        (8, 18432, 7168),  # Small batch MLP down-proj
+        (128, 4096, 7168),  # Typical batch
+        (128, 7168, 18432),  # MLP up-projection
+        (128, 18432, 7168),  # MLP down-projection
+        (1024, 4096, 7168),  # Larger batch
+        (1024, 18432, 7168),  # Larger batch with MLP down-proj
+        (2048, 4096, 7168),  # Very large batch
+    ]
+
+    all_results = {}
+    for m, n, k in shapes:
+        shape_key = f"m{m}_n{n}_k{k}"
+        all_results[shape_key] = benchmark_shape(m, n, k)
+
+    print("\n===== BENCHMARK SUMMARY =====")
+    print("Matrix multiplication: C[m,n] = A[m,k] @ B[n,k].T")
+    print("\nAverage speedups:")
+
+    # Calculate average speedups across all shapes
+    speedups = {
+        "DeepGEMM vs vLLM Triton": [],
+        "DeepGEMM vs vLLM CUTLASS": [],
+        "vLLM CUTLASS vs vLLM Triton": []
+    }
+
+    for shape_key, results in all_results.items():
+        deepgemm_time = results["DeepGEMM"]["time_ms"]
+        vllm_triton_time = results["vLLM Triton"]["time_ms"]
+        vllm_cutlass_time = results["vLLM CUTLASS"]["time_ms"]
+
+        speedups["DeepGEMM vs vLLM Triton"].append(vllm_triton_time /
+                                                   deepgemm_time)
+        speedups["DeepGEMM vs vLLM CUTLASS"].append(vllm_cutlass_time /
+                                                    deepgemm_time)
+        speedups["vLLM CUTLASS vs vLLM Triton"].append(vllm_triton_time /
+                                                       vllm_cutlass_time)
+
+    for comparison, values in speedups.items():
+        avg_speedup = sum(values) / len(values)
+        print(f"{comparison}: {avg_speedup:.2f}x "
+              f"{'faster' if avg_speedup > 1 else 'slower'}")
+
+    print("\nAverage TFLOPS:")
+    implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
+    for impl in implementations:
+        avg_tflops = sum(
+            results[impl]["tflops"]
+            for results in all_results.values()) / len(all_results)
+        print(f"{impl}: {avg_tflops:.2f} TFLOPS")
+
+    print("\nAverage accuracy difference vs reference:")
+    for impl in implementations:
+        avg_diff = sum(results[impl]["diff"]["Reference"]
+                       for results in all_results.values()) / len(all_results)
+        print(f"{impl}: {avg_diff:.6f}")
+
+
+if __name__ == "__main__":
+    run_benchmarks()
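
A note on the quantization scheme the patch above uses: activations (`A`) get one float32 scale per 128-wide group along `k` (the group absmax is mapped to 448.0, the `float8_e4m3fn` maximum), while weights (`B`) get one scale per 128x128 tile. Below is a minimal, standalone sketch of the activation round trip; it re-declares the `per_token_cast_to_fp8` helper from PATCH 1, and the `rel_err` check is an added illustration, not part of the patches.

```python
import torch

# Minimal sketch of the per-token FP8 scheme from PATCH 1
# (assumes a CUDA GPU with float8_e4m3fn support).
def per_token_cast_to_fp8(x: torch.Tensor):
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    # One float32 scale per 128-wide group: group absmax mapped to 448.
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return x_fp8.view(m, n), (x_amax / 448.0).view(m, -1)

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
x_fp8, scales = per_token_cast_to_fp8(x)
# Dequantize by multiplying each 128-wide group by its stored scale.
x_deq = (x_fp8.view(4, -1, 128).float() * scales.unsqueeze(2)).view(4, 256)
rel_err = (x_deq - x.float()).abs().mean() / x.float().abs().mean()
print(f"mean relative round-trip error: {rel_err:.4f}")  # a few percent
```

The same idea at 128x128 granularity covers the weight side, which is why the benchmark quantizes `B` once up front but re-quantizes `A` inside the timed loop: weight scales can be computed offline, activation scales cannot.
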
From 83bf65a2827497076e4d7d0d384d7ece35b3ab42 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Wed, 26 Feb 2025 19:47:43 +0000
Subject: [PATCH 2/7] Fix DeepGEMM compare

Signed-off-by: mgoin
---
 benchmarks/kernels/deepgemm/README.md         | 118 +++++++++---------
 .../benchmark_fp8_block_dense_gemm.py         |   4 +-
 2 files changed, 60 insertions(+), 62 deletions(-)

diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md
index 02c5acd0be19..b02b006af813 100644
--- a/benchmarks/kernels/deepgemm/README.md
+++ b/benchmarks/kernels/deepgemm/README.md
@@ -17,54 +17,54 @@ uv pip install -e DeepGEMM
 ```
 python benchmark_fp8_block_dense_gemm.py
-INFO 02-26 19:12:16 [__init__.py:207] Automatically detected platform cuda.
+INFO 02-26 19:45:44 [__init__.py:207] Automatically detected platform cuda.
 ===== STARTING FP8 GEMM BENCHMARK =====
 Using device: NVIDIA H100 80GB HBM3
 
 === Benchmarking shape: m=8, n=4096, k=7168 ===
 Running correctness check...
-WARNING 02-26 19:12:19 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
+WARNING 02-26 19:45:47 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
 DeepGEMM vs Reference difference: 0.000689
 vLLM Triton vs Reference difference: 0.000691
 vLLM CUTLASS vs Reference difference: 0.000691
 vLLM Triton vs DeepGEMM difference: 0.000011
 vLLM CUTLASS vs DeepGEMM difference: 0.000011
-DeepGEMM: 0.129 ms, 3.64 TFLOPS
-vLLM Triton: 0.074 ms, 6.35 TFLOPS
+DeepGEMM: 0.111 ms, 4.25 TFLOPS
+vLLM Triton: 0.074 ms, 6.39 TFLOPS
 vLLM CUTLASS: 0.034 ms, 13.71 TFLOPS
-DeepGEMM is 1.74x faster than vLLM Triton
-DeepGEMM is 3.76x faster than vLLM CUTLASS
-vLLM CUTLASS is 2.16x faster than vLLM Triton
+DeepGEMM is 0.66x slower than vLLM Triton
+DeepGEMM is 0.31x slower than vLLM CUTLASS
+vLLM CUTLASS is 2.15x faster than vLLM Triton
 
 === Benchmarking shape: m=8, n=7168, k=18432 ===
 Running correctness check...
-INFO 02-26 19:12:19 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
+INFO 02-26 19:45:47 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
 DeepGEMM vs Reference difference: 0.000680
 vLLM Triton vs Reference difference: 0.000680
 vLLM CUTLASS vs Reference difference: 0.000680
 vLLM Triton vs DeepGEMM difference: 0.000010
 vLLM CUTLASS vs DeepGEMM difference: 0.000010
-DeepGEMM: 0.114 ms, 18.48 TFLOPS
-vLLM Triton: 0.091 ms, 23.14 TFLOPS
-vLLM CUTLASS: 0.082 ms, 25.86 TFLOPS
-DeepGEMM is 1.25x faster than vLLM Triton
-DeepGEMM is 1.40x faster than vLLM CUTLASS
-vLLM CUTLASS is 1.12x faster than vLLM Triton
+DeepGEMM: 0.112 ms, 18.83 TFLOPS
+vLLM Triton: 0.092 ms, 22.86 TFLOPS
+vLLM CUTLASS: 0.081 ms, 26.15 TFLOPS
+DeepGEMM is 0.82x slower than vLLM Triton
+DeepGEMM is 0.72x slower than vLLM CUTLASS
+vLLM CUTLASS is 1.14x faster than vLLM Triton
 
 === Benchmarking shape: m=8, n=18432, k=7168 ===
 Running correctness check...
-WARNING 02-26 19:12:19 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +WARNING 02-26 19:45:47 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json DeepGEMM vs Reference difference: 0.000682 vLLM Triton vs Reference difference: 0.000682 vLLM CUTLASS vs Reference difference: 0.000682 vLLM Triton vs DeepGEMM difference: 0.000005 vLLM CUTLASS vs DeepGEMM difference: 0.000005 -DeepGEMM: 0.113 ms, 18.68 TFLOPS -vLLM Triton: 0.117 ms, 18.03 TFLOPS -vLLM CUTLASS: 0.082 ms, 25.76 TFLOPS -DeepGEMM is 0.97x slower than vLLM Triton -DeepGEMM is 1.38x faster than vLLM CUTLASS -vLLM CUTLASS is 1.43x faster than vLLM Triton +DeepGEMM: 0.109 ms, 19.35 TFLOPS +vLLM Triton: 0.117 ms, 18.06 TFLOPS +vLLM CUTLASS: 0.081 ms, 26.21 TFLOPS +DeepGEMM is 1.07x faster than vLLM Triton +DeepGEMM is 0.74x slower than vLLM CUTLASS +vLLM CUTLASS is 1.45x faster than vLLM Triton === Benchmarking shape: m=128, n=4096, k=7168 === Running correctness check... @@ -73,12 +73,12 @@ vLLM Triton vs Reference difference: 0.000682 vLLM CUTLASS vs Reference difference: 0.000682 vLLM Triton vs DeepGEMM difference: 0.000007 vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.114 ms, 65.79 TFLOPS +DeepGEMM: 0.109 ms, 68.76 TFLOPS vLLM Triton: 0.091 ms, 82.65 TFLOPS -vLLM CUTLASS: 0.039 ms, 191.25 TFLOPS -DeepGEMM is 1.26x faster than vLLM Triton -DeepGEMM is 2.91x faster than vLLM CUTLASS -vLLM CUTLASS is 2.31x faster than vLLM Triton +vLLM CUTLASS: 0.039 ms, 190.49 TFLOPS +DeepGEMM is 0.83x slower than vLLM Triton +DeepGEMM is 0.36x slower than vLLM CUTLASS +vLLM CUTLASS is 2.30x faster than vLLM Triton === Benchmarking shape: m=128, n=7168, k=18432 === Running correctness check... @@ -87,12 +87,12 @@ vLLM Triton vs Reference difference: 0.000683 vLLM CUTLASS vs Reference difference: 0.000683 vLLM Triton vs DeepGEMM difference: 0.000008 vLLM CUTLASS vs DeepGEMM difference: 0.000008 -DeepGEMM: 0.115 ms, 293.95 TFLOPS -vLLM Triton: 0.143 ms, 236.69 TFLOPS -vLLM CUTLASS: 0.093 ms, 363.23 TFLOPS -DeepGEMM is 0.81x slower than vLLM Triton -DeepGEMM is 1.24x faster than vLLM CUTLASS -vLLM CUTLASS is 1.53x faster than vLLM Triton +DeepGEMM: 0.115 ms, 294.42 TFLOPS +vLLM Triton: 0.142 ms, 237.38 TFLOPS +vLLM CUTLASS: 0.093 ms, 361.90 TFLOPS +DeepGEMM is 1.24x faster than vLLM Triton +DeepGEMM is 0.81x slower than vLLM CUTLASS +vLLM CUTLASS is 1.52x faster than vLLM Triton === Benchmarking shape: m=128, n=18432, k=7168 === Running correctness check... 
@@ -101,12 +101,12 @@ vLLM Triton vs Reference difference: 0.000684 vLLM CUTLASS vs Reference difference: 0.000684 vLLM Triton vs DeepGEMM difference: 0.000007 vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.112 ms, 301.67 TFLOPS -vLLM Triton: 0.228 ms, 148.41 TFLOPS -vLLM CUTLASS: 0.086 ms, 395.53 TFLOPS -DeepGEMM is 0.49x slower than vLLM Triton -DeepGEMM is 1.31x faster than vLLM CUTLASS -vLLM CUTLASS is 2.67x faster than vLLM Triton +DeepGEMM: 0.110 ms, 308.47 TFLOPS +vLLM Triton: 0.228 ms, 148.56 TFLOPS +vLLM CUTLASS: 0.086 ms, 394.22 TFLOPS +DeepGEMM is 2.08x faster than vLLM Triton +DeepGEMM is 0.78x slower than vLLM CUTLASS +vLLM CUTLASS is 2.65x faster than vLLM Triton === Benchmarking shape: m=1024, n=4096, k=7168 === Running correctness check... @@ -115,12 +115,12 @@ vLLM Triton vs Reference difference: 0.000683 vLLM CUTLASS vs Reference difference: 0.000683 vLLM Triton vs DeepGEMM difference: 0.000007 vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.171 ms, 351.94 TFLOPS -vLLM Triton: 0.241 ms, 249.66 TFLOPS -vLLM CUTLASS: 0.101 ms, 598.08 TFLOPS -DeepGEMM is 0.71x slower than vLLM Triton -DeepGEMM is 1.70x faster than vLLM CUTLASS -vLLM CUTLASS is 2.40x faster than vLLM Triton +DeepGEMM: 0.169 ms, 356.31 TFLOPS +vLLM Triton: 0.241 ms, 249.85 TFLOPS +vLLM CUTLASS: 0.101 ms, 592.45 TFLOPS +DeepGEMM is 1.43x faster than vLLM Triton +DeepGEMM is 0.60x slower than vLLM CUTLASS +vLLM CUTLASS is 2.37x faster than vLLM Triton === Benchmarking shape: m=1024, n=18432, k=7168 === Running correctness check... @@ -129,11 +129,11 @@ vLLM Triton vs Reference difference: 0.000684 vLLM CUTLASS vs Reference difference: 0.000684 vLLM Triton vs DeepGEMM difference: 0.000007 vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.347 ms, 780.08 TFLOPS -vLLM Triton: 0.898 ms, 301.38 TFLOPS -vLLM CUTLASS: 0.331 ms, 817.56 TFLOPS -DeepGEMM is 0.39x slower than vLLM Triton -DeepGEMM is 1.05x faster than vLLM CUTLASS +DeepGEMM: 0.347 ms, 779.63 TFLOPS +vLLM Triton: 0.898 ms, 301.41 TFLOPS +vLLM CUTLASS: 0.331 ms, 818.21 TFLOPS +DeepGEMM is 2.59x faster than vLLM Triton +DeepGEMM is 0.95x slower than vLLM CUTLASS vLLM CUTLASS is 2.71x faster than vLLM Triton === Benchmarking shape: m=2048, n=4096, k=7168 === @@ -143,25 +143,25 @@ vLLM Triton vs Reference difference: 0.000683 vLLM CUTLASS vs Reference difference: 0.000683 vLLM Triton vs DeepGEMM difference: 0.000007 vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.321 ms, 374.33 TFLOPS -vLLM Triton: 0.461 ms, 261.05 TFLOPS -vLLM CUTLASS: 0.200 ms, 601.60 TFLOPS -DeepGEMM is 0.70x slower than vLLM Triton -DeepGEMM is 1.61x faster than vLLM CUTLASS +DeepGEMM: 0.320 ms, 376.32 TFLOPS +vLLM Triton: 0.460 ms, 261.25 TFLOPS +vLLM CUTLASS: 0.200 ms, 602.18 TFLOPS +DeepGEMM is 1.44x faster than vLLM Triton +DeepGEMM is 0.62x slower than vLLM CUTLASS vLLM CUTLASS is 2.30x faster than vLLM Triton ===== BENCHMARK SUMMARY ===== Matrix multiplication: C[m,n] = A[m,k] @ B[n,k].T Average speedups: -DeepGEMM vs vLLM Triton: 1.32x faster -DeepGEMM vs vLLM CUTLASS: 0.64x slower +DeepGEMM vs vLLM Triton: 1.35x faster +DeepGEMM vs vLLM CUTLASS: 0.66x slower vLLM CUTLASS vs vLLM Triton: 2.07x faster Average TFLOPS: -DeepGEMM: 245.40 TFLOPS -vLLM Triton: 147.48 TFLOPS -vLLM CUTLASS: 336.95 TFLOPS +DeepGEMM: 247.37 TFLOPS +vLLM Triton: 147.60 TFLOPS +vLLM CUTLASS: 336.17 TFLOPS Average accuracy difference vs reference: DeepGEMM: 0.000683 diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py 
b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index 7c74be91d829..a90220c34f02 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -167,7 +167,7 @@ def vllm_cutlass_gemm(): baseline = results["DeepGEMM"]["time_ms"] for name, data in results.items(): if name != "DeepGEMM": - speedup = baseline / data["time_ms"] + speedup = data["time_ms"] / baseline print(f"DeepGEMM is {speedup:.2f}x " f"{'faster' if speedup > 1 else 'slower'} than {name}") @@ -201,9 +201,7 @@ def run_benchmarks(): torch.cuda.manual_seed(42) # Define benchmark shapes (m, n, k) - # Common matrix shapes from LLM inference shapes = [ - # Batch sizes x hidden dim, output dim, hidden dim (8, 4096, 7168), # Small batch (8, 7168, 18432), # Small batch MLP up-proj (8, 18432, 7168), # Small batch MLP down-proj From 7963a9b0c7977ac06a4ee74a516233b9b6b7de60 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 26 Feb 2025 21:43:07 +0000 Subject: [PATCH 3/7] Update Signed-off-by: mgoin --- benchmarks/kernels/deepgemm/README.md | 241 ++++------ .../benchmark_fp8_block_dense_gemm.py | 66 ++- .../benchmark_fp8_block_dense_gemm_table.py | 435 ++++++++++++++++++ 3 files changed, 582 insertions(+), 160 deletions(-) create mode 100644 benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md index b02b006af813..6edc376aa251 100644 --- a/benchmarks/kernels/deepgemm/README.md +++ b/benchmarks/kernels/deepgemm/README.md @@ -16,155 +16,100 @@ uv pip install -e DeepGEMM ## Usage ``` -python benchmark_fp8_block_dense_gemm.py -INFO 02-26 19:45:44 [__init__.py:207] Automatically detected platform cuda. +python benchmark_fp8_block_dense_gemm_table.py +INFO 02-26 21:35:35 [__init__.py:207] Automatically detected platform cuda. ===== STARTING FP8 GEMM BENCHMARK ===== +PyTorch version: 2.5.1+cu124 +CUDA version: 12.4 +Triton version: 3.1.0 Using device: NVIDIA H100 80GB HBM3 -=== Benchmarking shape: m=8, n=4096, k=7168 === -Running correctness check... -WARNING 02-26 19:45:47 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json -DeepGEMM vs Reference difference: 0.000689 -vLLM Triton vs Reference difference: 0.000691 -vLLM CUTLASS vs Reference difference: 0.000691 -vLLM Triton vs DeepGEMM difference: 0.000011 -vLLM CUTLASS vs DeepGEMM difference: 0.000011 -DeepGEMM: 0.111 ms, 4.25 TFLOPS -vLLM Triton: 0.074 ms, 6.39 TFLOPS -vLLM CUTLASS: 0.034 ms, 13.71 TFLOPS -DeepGEMM is 0.66x slower than vLLM Triton -DeepGEMM is 0.31x slower than vLLM CUTLASS -vLLM CUTLASS is 2.15x faster than vLLM Triton - -=== Benchmarking shape: m=8, n=7168, k=18432 === -Running correctness check... -INFO 02-26 19:45:47 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. 
-DeepGEMM vs Reference difference: 0.000680 -vLLM Triton vs Reference difference: 0.000680 -vLLM CUTLASS vs Reference difference: 0.000680 -vLLM Triton vs DeepGEMM difference: 0.000010 -vLLM CUTLASS vs DeepGEMM difference: 0.000010 -DeepGEMM: 0.112 ms, 18.83 TFLOPS -vLLM Triton: 0.092 ms, 22.86 TFLOPS -vLLM CUTLASS: 0.081 ms, 26.15 TFLOPS -DeepGEMM is 0.82x slower than vLLM Triton -DeepGEMM is 0.72x slower than vLLM CUTLASS -vLLM CUTLASS is 1.14x faster than vLLM Triton - -=== Benchmarking shape: m=8, n=18432, k=7168 === -Running correctness check... -WARNING 02-26 19:45:47 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json -DeepGEMM vs Reference difference: 0.000682 -vLLM Triton vs Reference difference: 0.000682 -vLLM CUTLASS vs Reference difference: 0.000682 -vLLM Triton vs DeepGEMM difference: 0.000005 -vLLM CUTLASS vs DeepGEMM difference: 0.000005 -DeepGEMM: 0.109 ms, 19.35 TFLOPS -vLLM Triton: 0.117 ms, 18.06 TFLOPS -vLLM CUTLASS: 0.081 ms, 26.21 TFLOPS -DeepGEMM is 1.07x faster than vLLM Triton -DeepGEMM is 0.74x slower than vLLM CUTLASS -vLLM CUTLASS is 1.45x faster than vLLM Triton - -=== Benchmarking shape: m=128, n=4096, k=7168 === -Running correctness check... -DeepGEMM vs Reference difference: 0.000682 -vLLM Triton vs Reference difference: 0.000682 -vLLM CUTLASS vs Reference difference: 0.000682 -vLLM Triton vs DeepGEMM difference: 0.000007 -vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.109 ms, 68.76 TFLOPS -vLLM Triton: 0.091 ms, 82.65 TFLOPS -vLLM CUTLASS: 0.039 ms, 190.49 TFLOPS -DeepGEMM is 0.83x slower than vLLM Triton -DeepGEMM is 0.36x slower than vLLM CUTLASS -vLLM CUTLASS is 2.30x faster than vLLM Triton - -=== Benchmarking shape: m=128, n=7168, k=18432 === -Running correctness check... -DeepGEMM vs Reference difference: 0.000683 -vLLM Triton vs Reference difference: 0.000683 -vLLM CUTLASS vs Reference difference: 0.000683 -vLLM Triton vs DeepGEMM difference: 0.000008 -vLLM CUTLASS vs DeepGEMM difference: 0.000008 -DeepGEMM: 0.115 ms, 294.42 TFLOPS -vLLM Triton: 0.142 ms, 237.38 TFLOPS -vLLM CUTLASS: 0.093 ms, 361.90 TFLOPS -DeepGEMM is 1.24x faster than vLLM Triton -DeepGEMM is 0.81x slower than vLLM CUTLASS -vLLM CUTLASS is 1.52x faster than vLLM Triton - -=== Benchmarking shape: m=128, n=18432, k=7168 === -Running correctness check... -DeepGEMM vs Reference difference: 0.000684 -vLLM Triton vs Reference difference: 0.000684 -vLLM CUTLASS vs Reference difference: 0.000684 -vLLM Triton vs DeepGEMM difference: 0.000007 -vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.110 ms, 308.47 TFLOPS -vLLM Triton: 0.228 ms, 148.56 TFLOPS -vLLM CUTLASS: 0.086 ms, 394.22 TFLOPS -DeepGEMM is 2.08x faster than vLLM Triton -DeepGEMM is 0.78x slower than vLLM CUTLASS -vLLM CUTLASS is 2.65x faster than vLLM Triton - -=== Benchmarking shape: m=1024, n=4096, k=7168 === -Running correctness check... 
-DeepGEMM vs Reference difference: 0.000683 -vLLM Triton vs Reference difference: 0.000683 -vLLM CUTLASS vs Reference difference: 0.000683 -vLLM Triton vs DeepGEMM difference: 0.000007 -vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.169 ms, 356.31 TFLOPS -vLLM Triton: 0.241 ms, 249.85 TFLOPS -vLLM CUTLASS: 0.101 ms, 592.45 TFLOPS -DeepGEMM is 1.43x faster than vLLM Triton -DeepGEMM is 0.60x slower than vLLM CUTLASS -vLLM CUTLASS is 2.37x faster than vLLM Triton - -=== Benchmarking shape: m=1024, n=18432, k=7168 === -Running correctness check... -DeepGEMM vs Reference difference: 0.000684 -vLLM Triton vs Reference difference: 0.000684 -vLLM CUTLASS vs Reference difference: 0.000684 -vLLM Triton vs DeepGEMM difference: 0.000007 -vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.347 ms, 779.63 TFLOPS -vLLM Triton: 0.898 ms, 301.41 TFLOPS -vLLM CUTLASS: 0.331 ms, 818.21 TFLOPS -DeepGEMM is 2.59x faster than vLLM Triton -DeepGEMM is 0.95x slower than vLLM CUTLASS -vLLM CUTLASS is 2.71x faster than vLLM Triton - -=== Benchmarking shape: m=2048, n=4096, k=7168 === -Running correctness check... -DeepGEMM vs Reference difference: 0.000683 -vLLM Triton vs Reference difference: 0.000683 -vLLM CUTLASS vs Reference difference: 0.000683 -vLLM Triton vs DeepGEMM difference: 0.000007 -vLLM CUTLASS vs DeepGEMM difference: 0.000007 -DeepGEMM: 0.320 ms, 376.32 TFLOPS -vLLM Triton: 0.460 ms, 261.25 TFLOPS -vLLM CUTLASS: 0.200 ms, 602.18 TFLOPS -DeepGEMM is 1.44x faster than vLLM Triton -DeepGEMM is 0.62x slower than vLLM CUTLASS -vLLM CUTLASS is 2.30x faster than vLLM Triton - -===== BENCHMARK SUMMARY ===== -Matrix multiplication: C[m,n] = A[m,k] @ B[n,k].T - -Average speedups: -DeepGEMM vs vLLM Triton: 1.35x faster -DeepGEMM vs vLLM CUTLASS: 0.66x slower -vLLM CUTLASS vs vLLM Triton: 2.07x faster - -Average TFLOPS: -DeepGEMM: 247.37 TFLOPS -vLLM Triton: 147.60 TFLOPS -vLLM CUTLASS: 336.17 TFLOPS - -Average accuracy difference vs reference: -DeepGEMM: 0.000683 -vLLM Triton: 0.000684 -vLLM CUTLASS: 0.000684 +===== PERFORMANCE COMPARISON ===== + +DeepGEMM Implementation: ++------+-------+-------+-----------+--------+--------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | ++------+-------+-------+-----------+--------+--------+ +| 8 | 4096 | 7168 | 85.1 | 5.5 | 346.3 | +| 8 | 7168 | 18432 | 83.9 | 25.2 | 1577.3 | +| 8 | 18432 | 7168 | 84.1 | 25.1 | 1576.0 | +| 64 | 24576 | 1536 | 86.1 | 56.1 | 476.0 | +| 64 | 32768 | 512 | 84.0 | 25.6 | 250.1 | +| 64 | 7168 | 16384 | 120.3 | 124.9 | 992.5 | +| 64 | 4096 | 7168 | 84.5 | 44.5 | 359.3 | +| 128 | 4096 | 7168 | 85.0 | 88.4 | 368.5 | +| 128 | 7168 | 18432 | 88.3 | 383.0 | 1543.5 | +| 128 | 18432 | 7168 | 86.4 | 391.4 | 1594.3 | +| 1024 | 4096 | 7168 | 91.7 | 655.5 | 491.5 | +| 1024 | 18432 | 7168 | 283.3 | 955.0 | 625.4 | +| 2048 | 4096 | 7168 | 177.6 | 677.1 | 342.4 | +| 4096 | 4096 | 7168 | 338.9 | 709.6 | 272.2 | ++------+-------+-------+-----------+--------+--------+ + +vLLM Triton Implementation: ++------+-------+-------+-----------+--------+--------+--------------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | ++------+-------+-------+-----------+--------+--------+--------------+ +| 8 | 4096 | 7168 | 74.4 | 6.3 | 396.4 | 1.14x faster | +| 8 | 7168 | 18432 | 89.6 | 23.6 | 1476.7 | 0.94x slower | +| 8 | 18432 | 7168 | 116.5 | 18.1 | 1137.3 | 0.72x slower | +| 64 | 24576 | 1536 | 37.2 | 129.9 | 1101.8 | 2.31x faster | +| 64 | 32768 | 512 | 38.7 | 55.5 | 542.4 | 2.17x faster | +| 64 | 7168 | 16384 | 86.7 | 173.3 | 1376.5 | 1.39x faster | 
+| 64 | 4096 | 7168 | 76.9 | 48.8 | 394.4 | 1.10x faster | +| 128 | 4096 | 7168 | 89.2 | 84.2 | 351.0 | 0.95x slower | +| 128 | 7168 | 18432 | 142.9 | 236.8 | 954.2 | 0.62x slower | +| 128 | 18432 | 7168 | 227.5 | 148.7 | 605.5 | 0.38x slower | +| 1024 | 4096 | 7168 | 240.7 | 249.8 | 187.3 | 0.38x slower | +| 1024 | 18432 | 7168 | 901.9 | 300.0 | 196.5 | 0.31x slower | +| 2048 | 4096 | 7168 | 462.6 | 260.0 | 131.5 | 0.38x slower | +| 4096 | 4096 | 7168 | 901.6 | 266.8 | 102.3 | 0.38x slower | ++------+-------+-------+-----------+--------+--------+--------------+ + +vLLM CUTLASS Implementation: ++------+-------+-------+-----------+--------+--------+--------------+--------------+ +| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton | ++------+-------+-------+-----------+--------+--------+--------------+--------------+ +| 8 | 4096 | 7168 | 33.9 | 13.9 | 869.7 | 2.51x faster | 2.19x faster | +| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.7 | 1.06x faster | 1.14x faster | +| 8 | 18432 | 7168 | 80.3 | 26.3 | 1649.8 | 1.05x faster | 1.45x faster | +| 64 | 24576 | 1536 | 28.3 | 170.9 | 1449.8 | 3.05x faster | 1.32x faster | +| 64 | 32768 | 512 | 27.8 | 77.2 | 754.8 | 3.02x faster | 1.39x faster | +| 64 | 7168 | 16384 | 78.5 | 191.6 | 1522.1 | 1.53x faster | 1.11x faster | +| 64 | 4096 | 7168 | 36.4 | 103.2 | 833.4 | 2.32x faster | 2.11x faster | +| 128 | 4096 | 7168 | 39.1 | 192.3 | 801.4 | 2.17x faster | 2.28x faster | +| 128 | 7168 | 18432 | 92.9 | 364.0 | 1467.1 | 0.95x slower | 1.54x faster | +| 128 | 18432 | 7168 | 85.6 | 395.1 | 1609.2 | 1.01x faster | 2.66x faster | +| 1024 | 4096 | 7168 | 100.6 | 597.7 | 448.2 | 0.91x slower | 2.39x faster | +| 1024 | 18432 | 7168 | 329.8 | 820.5 | 537.4 | 0.86x slower | 2.73x faster | +| 2048 | 4096 | 7168 | 198.7 | 605.1 | 306.0 | 0.89x slower | 2.33x faster | +| 4096 | 4096 | 7168 | 393.0 | 612.0 | 234.8 | 0.86x slower | 2.29x faster | ++------+-------+-------+-----------+--------+--------+--------------+--------------+ + +===== AVERAGE PERFORMANCE ===== ++----------------+------------+----------+---------------+ +| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) | ++----------------+------------+----------+---------------+ +| DeepGEMM | 297.65 | 772.53 | 0.13 | +| vLLM Triton | 142.98 | 639.56 | 0.25 | +| vLLM CUTLASS | 299.75 | 1011.51 | 0.11 | ++----------------+------------+----------+---------------+ + +===== AVERAGE SPEEDUPS ===== ++-----------------------------+--------------+ +| Comparison | Speedup | ++-----------------------------+--------------+ +| DeepGEMM vs vLLM Triton | 1.59x faster | +| DeepGEMM vs vLLM CUTLASS | 0.79x slower | +| vLLM CUTLASS vs vLLM Triton | 1.92x faster | ++-----------------------------+--------------+ + +===== ACCURACY COMPARISON ===== ++----------------+-----------------------+ +| Implementation | Avg Diff vs Reference | ++----------------+-----------------------+ +| DeepGEMM | 0.000685 | +| vLLM Triton | 0.000685 | +| vLLM CUTLASS | 0.000685 | ++----------------+-----------------------+ ``` diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index a90220c34f02..bebede20f6ec 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +# fmt: off +# ruff: noqa: E501 import time from typing import Dict, Tuple @@ -144,12 +146,13 @@ def vllm_cutlass_gemm(): # Calculate timing and 
TFLOPS
         avg_time_ms = (end - start) / repeat * 1000
-        flops = 2 * m * n * k  # multiply-adds
-        tflops = flops / (avg_time_ms * 1e-3) / 1e12
+        tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
+        mem_bw = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)
 
         results[name] = {
             "time_ms": avg_time_ms,
             "tflops": tflops,
+            "mem_bw": mem_bw,
             "diff": {
                 "DeepGEMM":
                 deepgemm_diff if name == "DeepGEMM" else calc_diff(
@@ -161,7 +164,9 @@ def vllm_cutlass_gemm():
             }
         }
 
-        print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS")
+        print(
+            f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {mem_bw:.2f} GB/s"
+        )
 
     # Calculate speedups
     baseline = results["DeepGEMM"]["time_ms"]
@@ -202,17 +207,46 @@ def run_benchmarks():
     torch.cuda.manual_seed(42)
 
     # Define benchmark shapes (m, n, k)
     shapes = [
-        (8, 4096, 7168),  # Small batch
-        (8, 7168, 18432),  # Small batch MLP up-proj
-        (8, 18432, 7168),  # Small batch MLP down-proj
-        (128, 4096, 7168),  # Typical batch
-        (128, 7168, 18432),  # MLP up-projection
-        (128, 18432, 7168),  # MLP down-projection
-        (1024, 4096, 7168),  # Larger batch
-        (1024, 18432, 7168),  # Larger batch with MLP down-proj
-        (2048, 4096, 7168),  # Very large batch
+        (8, 4096, 7168),
+        (8, 7168, 18432),
+        (8, 18432, 7168),
+        (64, 24576, 1536),
+        (64, 32768, 512),
+        (64, 7168, 16384),
+        (64, 4096, 7168),
+        (128, 4096, 7168),
+        (128, 7168, 18432),
+        (128, 18432, 7168),
+        (1024, 4096, 7168),
+        (1024, 18432, 7168),
+        (2048, 4096, 7168),
     ]
 
+    # # Taken from
+    # # https://github.com/deepseek-ai/DeepGEMM?tab=readme-ov-file#normal-gemms-for-dense-models
+    # shapes = [
+    #     # (64, 2112, 7168), # Unsupported by CUTLASS
+    #     (64, 24576, 1536),
+    #     (64, 32768, 512),
+    #     (64, 7168, 16384),
+    #     (64, 4096, 7168),
+    #     # (64, 7168, 2048), # Unsupported by DeepGEMM
+
+    #     # (128, 2112, 7168), # Unsupported by CUTLASS
+    #     # (128, 24576, 1536), # Unsupported by DeepGEMM
+    #     (128, 32768, 512),
+    #     (128, 7168, 16384),
+    #     (128, 4096, 7168),
+    #     # (128, 7168, 2048), # Unsupported by DeepGEMM
+
+    #     # (4096, 2112, 7168), # Unsupported by CUTLASS
+    #     (4096, 24576, 1536),
+    #     (4096, 32768, 512),
+    #     (4096, 7168, 16384),
+    #     (4096, 4096, 7168),
+    #     # (4096, 7168, 2048), # Unsupported by DeepGEMM
+    # ]
+
     all_results = {}
     for m, n, k in shapes:
         shape_key = f"m{m}_n{n}_k{k}"
@@ -254,6 +288,14 @@ def run_benchmarks():
             for results in all_results.values()) / len(all_results)
         print(f"{impl}: {avg_tflops:.2f} TFLOPS")
 
+    print("\nAverage GB/s:")
+    implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
+    for impl in implementations:
+        avg_mem_bw = sum(
+            results[impl]["mem_bw"]
+            for results in all_results.values()) / len(all_results)
+        print(f"{impl}: {avg_mem_bw:.2f} GB/s")
+
     print("\nAverage accuracy difference vs reference:")
     for impl in implementations:
         avg_diff = sum(results[impl]["diff"]["Reference"]
diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py
new file mode 100644
index 000000000000..39e85476b3c1
--- /dev/null
+++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py
@@ -0,0 +1,435 @@
+# SPDX-License-Identifier: Apache-2.0
+# fmt: off
+# ruff: noqa: E501
+import time
+from typing import Dict, Tuple
+
+# Import DeepGEMM functions
+import deep_gemm
+import torch
+import triton
+from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
+
+# Import vLLM functions
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8, 
w8a8_block_fp8_matmul)
+
+
+# Copied from
+# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
+def per_token_cast_to_fp8(
+        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Convert tensor to FP8 format with per-token scaling."""
+    assert x.dim() == 2 and x.size(1) % 128 == 0
+    m, n = x.shape
+    x_view = x.view(m, -1, 128)
+    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
+    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
+        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+
+
+# Copied from
+# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
+def per_block_cast_to_fp8(
+        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Convert tensor to FP8 format with per-block scaling."""
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
+                           dtype=x.dtype,
+                           device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
+        x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+
+
+def benchmark_shape(m: int,
+                    n: int,
+                    k: int,
+                    warmup: int = 10,
+                    repeat: int = 1000,
+                    verbose: bool = False) -> Dict:
+    """Benchmark all implementations for a specific (m, n, k) shape."""
+    if verbose:
+        print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
+
+    # Create test tensors
+    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
+    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+
+    # Reference result in BF16
+    torch.cuda.synchronize()
+    C_ref = A @ B.t()
+
+    # Pre-quantize B for all implementations
+    # (weights can be pre-quantized offline)
+    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
+    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
+
+    # Block size configuration
+    block_size = [128, 128]
+
+    # === DeepGEMM Implementation ===
+    def deepgemm_gemm():
+        # A quantization is inside the loop as it depends on activations
+        # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
+        A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
+            A, block_size[1])
+        A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
+        C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_aligned),
+                                       (B_deepgemm, B_scale_deepgemm),
+                                       C_deepgemm)
+        return C_deepgemm
+
+    # === vLLM Triton Implementation ===
+    def vllm_triton_gemm():
+        # A quantization is inside the loop as it depends on activations
+        A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
+        return w8a8_block_fp8_matmul(A_vllm,
+                                     B_vllm,
+                                     A_scale_vllm,
+                                     B_scale_vllm,
+                                     block_size,
+                                     output_dtype=torch.bfloat16)
+
+    # === vLLM CUTLASS Implementation ===
+    def vllm_cutlass_gemm():
+        # A quantization is inside the loop as it depends on activations
+        A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
+            A, block_size[1], column_major_scales=True)
+        return ops.cutlass_scaled_mm(A_vllm_cutlass,
+                                     B_vllm.T,
+                                     scale_a=A_scale_vllm_cutlass,
+                                     scale_b=B_scale_vllm.T,
+                                     out_dtype=torch.bfloat16)
+
+    # Run correctness check first
+    if verbose:
+        print("Running correctness check...")
+    C_deepgemm = deepgemm_gemm()
+    C_vllm_triton = vllm_triton_gemm()
+    
C_vllm_cutlass = vllm_cutlass_gemm() + + deepgemm_diff = calc_diff(C_deepgemm, C_ref) + vllm_triton_diff = calc_diff(C_vllm_triton, C_ref) + vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref) + + if verbose: + print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") + print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") + print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") + print("vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") + print("vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + + # Benchmark implementations + implementations = { + "DeepGEMM": deepgemm_gemm, + "vLLM Triton": vllm_triton_gemm, + "vLLM CUTLASS": vllm_cutlass_gemm + } + + benchmark_results = { + "shape": { + "m": m, + "n": n, + "k": k + }, + "implementations": {} + } + + for name, func in implementations.items(): + # Warmup + for _ in range(warmup): + func() + torch.cuda.synchronize() + + # Timing loop + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + func() + torch.cuda.synchronize() + end = time.time() + + # Calculate timing and TFLOPS + avg_time_ms = (end - start) / repeat * 1000 + avg_time_us = avg_time_ms * 1000 + tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12 + gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3) + + benchmark_results["implementations"][name] = { + "time_ms": avg_time_ms, + "time_us": avg_time_us, + "tflops": tflops, + "gb_s": gb_s, + "diff": { + "DeepGEMM": + 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), + "Reference": + deepgemm_diff if name == "DeepGEMM" else + (vllm_triton_diff + if name == "vLLM Triton" else vllm_cutlass_diff) + } + } + + if verbose: + print( + f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" + ) + + # Calculate speedups + baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] + for name, data in benchmark_results["implementations"].items(): + if name != "DeepGEMM": + speedup = baseline / data["time_ms"] + benchmark_results["implementations"][name][ + "speedup_vs_deepgemm"] = speedup + if verbose: + print(f"DeepGEMM is {1/speedup:.2f}x " + f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") + + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ + "time_ms"] + vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ + "time_ms"] + cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time + benchmark_results["implementations"]["vLLM CUTLASS"][ + "speedup_vs_triton"] = cutlass_vs_triton + if verbose: + print( + f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " + f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton" + ) + + return benchmark_results + + +def format_table_row(values, widths): + """Format a row with specified column widths.""" + return "| " + " | ".join(f"{val:{w}}" + for val, w in zip(values, widths)) + " |" + + +def print_table(headers, rows, title=None): + """Print a table with headers and rows.""" + if title: + print(f"\n{title}") + + # Calculate column widths based on headers and data + widths = [ + max(len(str(h)), max(len(str(row[i])) for row in rows)) + for i, h in enumerate(headers) + ] + + # Create separator line + separator = "+-" + "-+-".join("-" * w for w in widths) + "-+" + + # Print table + print(separator) + print(format_table_row(headers, widths)) + print(separator) + for row in rows: + print(format_table_row(row, widths)) + print(separator) + + +def format_speedup(value): + 
"""Format speedup value with indicator if it's faster or slower.""" + return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}" + + +def run_benchmarks(verbose: bool = False): + """Run benchmarks for a set of common shapes.""" + print("===== STARTING FP8 GEMM BENCHMARK =====") + + # Make sure we're using the GPU + if not torch.cuda.is_available(): + print("CUDA not available! Tests require GPU.") + return + + # Print system information + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA version: {torch.version.cuda}") + print(f"Triton version: {triton.__version__}") + print(f"Using device: {torch.cuda.get_device_name()}") + + # Enable TF32 for better performance + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Set seeds for reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Define benchmark shapes (m, n, k) + shapes = [ + (8, 4096, 7168), + (8, 7168, 18432), + (8, 18432, 7168), + (64, 24576, 1536), + (64, 32768, 512), + (64, 7168, 16384), + (64, 4096, 7168), + (128, 4096, 7168), + (128, 7168, 18432), + (128, 18432, 7168), + (1024, 4096, 7168), + (1024, 18432, 7168), + (2048, 4096, 7168), + (4096, 4096, 7168), + ] + + all_results = [] + for m, n, k in shapes: + result = benchmark_shape(m, n, k, verbose=verbose) + all_results.append(result) + + # Print results in a nicely formatted table + print("\n===== PERFORMANCE COMPARISON =====") + + # Print DeepGEMM table + deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"] + deepgemm_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["DeepGEMM"] + deepgemm_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" + ]) + + print_table(deepgemm_headers, + deepgemm_rows, + title="DeepGEMM Implementation:") + + # Print vLLM Triton table + triton_headers = [ + "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" + ] + triton_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM Triton"] + speedup = impl_data.get("speedup_vs_deepgemm", 1.0) + triton_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", + format_speedup(speedup) + ]) + + print_table(triton_headers, + triton_rows, + title="vLLM Triton Implementation:") + + # Print vLLM CUTLASS table + cutlass_headers = [ + "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", + "vs Triton" + ] + cutlass_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM CUTLASS"] + vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) + vs_triton = impl_data.get("speedup_vs_triton", 1.0) + cutlass_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton) + ]) + + print_table(cutlass_headers, + cutlass_rows, + title="vLLM CUTLASS Implementation:") + + # Calculate and print averages + print("\n===== AVERAGE PERFORMANCE =====") + + implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] + avg_metrics = { + impl: { + "tflops": 0, + "gb_s": 0, + "time_ms": 0 + } + for impl in implementations + } + + for result in all_results: + for impl in implementations: + impl_data = result["implementations"][impl] + avg_metrics[impl]["tflops"] 
+= impl_data["tflops"] + avg_metrics[impl]["gb_s"] += impl_data["gb_s"] + avg_metrics[impl]["time_ms"] += impl_data["time_ms"] + + num_shapes = len(all_results) + avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"] + avg_rows = [] + + for impl in implementations: + avg_tflops = avg_metrics[impl]["tflops"] / num_shapes + avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes + avg_time = avg_metrics[impl]["time_ms"] / num_shapes + avg_rows.append([ + impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" + ]) + + print_table(avg_headers, avg_rows) + + # Calculate average speedups + avg_speedups = { + "DeepGEMM vs vLLM Triton": 0, + "DeepGEMM vs vLLM CUTLASS": 0, + "vLLM CUTLASS vs vLLM Triton": 0 + } + + for result in all_results: + deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] + vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ + "time_ms"] + + avg_speedups[ + "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups[ + "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups[ + "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time + + print("\n===== AVERAGE SPEEDUPS =====") + speedup_headers = ["Comparison", "Speedup"] + speedup_rows = [] + for comparison, total in avg_speedups.items(): + avg_speedup = total / num_shapes + status = "faster" if avg_speedup > 1 else "slower" + speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"]) + + print_table(speedup_headers, speedup_rows) + + # Average accuracy comparison + print("\n===== ACCURACY COMPARISON =====") + avg_diff = {impl: 0 for impl in implementations} + + for result in all_results: + for impl in implementations: + avg_diff[impl] += result["implementations"][impl]["diff"][ + "Reference"] + + diff_headers = ["Implementation", "Avg Diff vs Reference"] + diff_rows = [] + for impl in implementations: + diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"]) + + print_table(diff_headers, diff_rows) + + +if __name__ == "__main__": + run_benchmarks(verbose=False) From c95c6adac7b9ce3e3c80529a61c07de2896adef9 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 26 Feb 2025 21:43:54 +0000 Subject: [PATCH 4/7] Update Signed-off-by: mgoin --- benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index bebede20f6ec..e6abf91715e9 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -75,7 +75,9 @@ def benchmark_shape(m: int, # === DeepGEMM Implementation === def deepgemm_gemm(): # A quantization is inside the loop as it depends on activations - A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) + # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) + A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( + A, block_size[1]) A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_aligned), From 1a6f96e9cf1e1cefe89585a1d5103136e5ea290f Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 26 Feb 2025 21:57:12 +0000 Subject: [PATCH 5/7] More update! 
Signed-off-by: mgoin --- benchmarks/kernels/deepgemm/README.md | 116 ++++++++++-------- .../benchmark_fp8_block_dense_gemm_table.py | 10 +- 2 files changed, 70 insertions(+), 56 deletions(-) diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md index 6edc376aa251..e2120f53f3d6 100644 --- a/benchmarks/kernels/deepgemm/README.md +++ b/benchmarks/kernels/deepgemm/README.md @@ -17,12 +17,18 @@ uv pip install -e DeepGEMM ``` python benchmark_fp8_block_dense_gemm_table.py -INFO 02-26 21:35:35 [__init__.py:207] Automatically detected platform cuda. +INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda. ===== STARTING FP8 GEMM BENCHMARK ===== PyTorch version: 2.5.1+cu124 CUDA version: 12.4 Triton version: 3.1.0 Using device: NVIDIA H100 80GB HBM3 +WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. +WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. +INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel. 
===== PERFORMANCE COMPARISON ===== @@ -30,86 +36,92 @@ DeepGEMM Implementation: +------+-------+-------+-----------+--------+--------+ | m | n | k | Time (μs) | TFLOPS | GB/s | +------+-------+-------+-----------+--------+--------+ -| 8 | 4096 | 7168 | 85.1 | 5.5 | 346.3 | -| 8 | 7168 | 18432 | 83.9 | 25.2 | 1577.3 | -| 8 | 18432 | 7168 | 84.1 | 25.1 | 1576.0 | -| 64 | 24576 | 1536 | 86.1 | 56.1 | 476.0 | -| 64 | 32768 | 512 | 84.0 | 25.6 | 250.1 | -| 64 | 7168 | 16384 | 120.3 | 124.9 | 992.5 | -| 64 | 4096 | 7168 | 84.5 | 44.5 | 359.3 | -| 128 | 4096 | 7168 | 85.0 | 88.4 | 368.5 | -| 128 | 7168 | 18432 | 88.3 | 383.0 | 1543.5 | -| 128 | 18432 | 7168 | 86.4 | 391.4 | 1594.3 | -| 1024 | 4096 | 7168 | 91.7 | 655.5 | 491.5 | -| 1024 | 18432 | 7168 | 283.3 | 955.0 | 625.4 | -| 2048 | 4096 | 7168 | 177.6 | 677.1 | 342.4 | -| 4096 | 4096 | 7168 | 338.9 | 709.6 | 272.2 | +| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 | +| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 | +| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 | +| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 | +| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 | +| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 | +| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 | +| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 | +| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 | +| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 | +| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 | +| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 | +| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 | +| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 | +| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 | +| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 | +------+-------+-------+-----------+--------+--------+ vLLM Triton Implementation: +------+-------+-------+-----------+--------+--------+--------------+ | m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | +------+-------+-------+-----------+--------+--------+--------------+ -| 8 | 4096 | 7168 | 74.4 | 6.3 | 396.4 | 1.14x faster | -| 8 | 7168 | 18432 | 89.6 | 23.6 | 1476.7 | 0.94x slower | -| 8 | 18432 | 7168 | 116.5 | 18.1 | 1137.3 | 0.72x slower | -| 64 | 24576 | 1536 | 37.2 | 129.9 | 1101.8 | 2.31x faster | -| 64 | 32768 | 512 | 38.7 | 55.5 | 542.4 | 2.17x faster | -| 64 | 7168 | 16384 | 86.7 | 173.3 | 1376.5 | 1.39x faster | -| 64 | 4096 | 7168 | 76.9 | 48.8 | 394.4 | 1.10x faster | -| 128 | 4096 | 7168 | 89.2 | 84.2 | 351.0 | 0.95x slower | -| 128 | 7168 | 18432 | 142.9 | 236.8 | 954.2 | 0.62x slower | -| 128 | 18432 | 7168 | 227.5 | 148.7 | 605.5 | 0.38x slower | -| 1024 | 4096 | 7168 | 240.7 | 249.8 | 187.3 | 0.38x slower | -| 1024 | 18432 | 7168 | 901.9 | 300.0 | 196.5 | 0.31x slower | -| 2048 | 4096 | 7168 | 462.6 | 260.0 | 131.5 | 0.38x slower | -| 4096 | 4096 | 7168 | 901.6 | 266.8 | 102.3 | 0.38x slower | +| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster | +| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower | +| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower | +| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower | +| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower | +| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower | +| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster | +| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster | +| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower | +| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster | +| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower | +| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 | 0.30x slower | +| 1024 | 4096 | 7168 | 242.3 | 248.2 | 186.1 | 0.37x slower | +| 
1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower | +| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower | +| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower | +------+-------+-------+-----------+--------+--------+--------------+ vLLM CUTLASS Implementation: +------+-------+-------+-----------+--------+--------+--------------+--------------+ | m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton | +------+-------+-------+-----------+--------+--------+--------------+--------------+ -| 8 | 4096 | 7168 | 33.9 | 13.9 | 869.7 | 2.51x faster | 2.19x faster | -| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.7 | 1.06x faster | 1.14x faster | -| 8 | 18432 | 7168 | 80.3 | 26.3 | 1649.8 | 1.05x faster | 1.45x faster | -| 64 | 24576 | 1536 | 28.3 | 170.9 | 1449.8 | 3.05x faster | 1.32x faster | -| 64 | 32768 | 512 | 27.8 | 77.2 | 754.8 | 3.02x faster | 1.39x faster | -| 64 | 7168 | 16384 | 78.5 | 191.6 | 1522.1 | 1.53x faster | 1.11x faster | -| 64 | 4096 | 7168 | 36.4 | 103.2 | 833.4 | 2.32x faster | 2.11x faster | -| 128 | 4096 | 7168 | 39.1 | 192.3 | 801.4 | 2.17x faster | 2.28x faster | -| 128 | 7168 | 18432 | 92.9 | 364.0 | 1467.1 | 0.95x slower | 1.54x faster | -| 128 | 18432 | 7168 | 85.6 | 395.1 | 1609.2 | 1.01x faster | 2.66x faster | -| 1024 | 4096 | 7168 | 100.6 | 597.7 | 448.2 | 0.91x slower | 2.39x faster | -| 1024 | 18432 | 7168 | 329.8 | 820.5 | 537.4 | 0.86x slower | 2.73x faster | -| 2048 | 4096 | 7168 | 198.7 | 605.1 | 306.0 | 0.89x slower | 2.33x faster | -| 4096 | 4096 | 7168 | 393.0 | 612.0 | 234.8 | 0.86x slower | 2.29x faster | +| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster | +| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster | +| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster | +| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster | +| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster | +| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster | +| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster | +| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster | +| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster | +| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster | +| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster | +| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster | +| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster | +| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster | +| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster | +| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster | +------+-------+-------+-----------+--------+--------+--------------+--------------+ ===== AVERAGE PERFORMANCE ===== +----------------+------------+----------+---------------+ | Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) | +----------------+------------+----------+---------------+ -| DeepGEMM | 297.65 | 772.53 | 0.13 | -| vLLM Triton | 142.98 | 639.56 | 0.25 | -| vLLM CUTLASS | 299.75 | 1011.51 | 0.11 | +| DeepGEMM | 310.98 | 1052.10 | 0.11 | +| vLLM Triton | 144.30 | 715.60 | 0.23 | +| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 | +----------------+------------+----------+---------------+ ===== AVERAGE SPEEDUPS ===== +-----------------------------+--------------+ | Comparison | Speedup | 
+-----------------------------+--------------+ -| DeepGEMM vs vLLM Triton | 1.59x faster | -| DeepGEMM vs vLLM CUTLASS | 0.79x slower | -| vLLM CUTLASS vs vLLM Triton | 1.92x faster | +| DeepGEMM vs vLLM Triton | 1.71x faster | +| DeepGEMM vs vLLM CUTLASS | 0.94x slower | +| vLLM CUTLASS vs vLLM Triton | 1.84x faster | +-----------------------------+--------------+ ===== ACCURACY COMPARISON ===== +----------------+-----------------------+ | Implementation | Avg Diff vs Reference | +----------------+-----------------------+ -| DeepGEMM | 0.000685 | -| vLLM Triton | 0.000685 | -| vLLM CUTLASS | 0.000685 | +| DeepGEMM | 0.000684 | +| vLLM Triton | 0.000684 | +| vLLM CUTLASS | 0.000684 | +----------------+-----------------------+ ``` diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py index 39e85476b3c1..f0a4cecdb59b 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py @@ -8,7 +8,7 @@ import deep_gemm import torch import triton -from deep_gemm import calc_diff, cell_div, get_col_major_tma_aligned_tensor +from deep_gemm import calc_diff, cell_div # Import vLLM functions from vllm import _custom_ops as ops @@ -79,9 +79,9 @@ def deepgemm_gemm(): # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( A, block_size[1]) - A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) + # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_aligned), + deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm) return C_deepgemm @@ -273,10 +273,12 @@ def run_benchmarks(verbose: bool = False): (8, 4096, 7168), (8, 7168, 18432), (8, 18432, 7168), + (64, 4096, 7168), + (64, 7168, 18432), + (64, 18432, 7168), (64, 24576, 1536), (64, 32768, 512), (64, 7168, 16384), - (64, 4096, 7168), (128, 4096, 7168), (128, 7168, 18432), (128, 18432, 7168), From b287504bacdbd733489ad6438589676d2369398f Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 27 Feb 2025 21:57:18 +0000 Subject: [PATCH 6/7] Update by removing quantization overhead Signed-off-by: mgoin --- .../benchmark_fp8_block_dense_gemm_table.py | 48 +++++++++++++++---- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py index f0a4cecdb59b..4af961a62adf 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py @@ -8,7 +8,7 @@ import deep_gemm import torch import triton -from deep_gemm import calc_diff, cell_div +from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor # Import vLLM functions from vllm import _custom_ops as ops @@ -36,7 +36,7 @@ def per_block_cast_to_fp8( """Convert tensor to FP8 format with per-block scaling.""" assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x @@ -50,8 +50,8 @@ def per_block_cast_to_fp8( def benchmark_shape(m: int, n: int, k: int, - 
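The hunk below is the heart of this patch: activation quantization moves out of the timed region, so the loop measures only the GEMM kernel rather than quantization plus GEMM. A rough sketch of the distinction, using the same warmup-then-synchronize pattern as the script (the helper names here are illustrative, not part of the benchmark):

```python
import time

import torch


def timed_us(fn, warmup: int = 100, repeat: int = 10000) -> float:
    """Average wall-clock time of fn in microseconds, GPU-synchronized."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / repeat * 1e6


# Before this patch (end-to-end):  timed_us(lambda: gemm(quant(A), B_q))
# After this patch (kernel-only):  A_q = quant(A); timed_us(lambda: gemm(A_q, B_q))
```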
warmup: int = 10, - repeat: int = 1000, + warmup: int = 100, + repeat: int = 10000, verbose: bool = False) -> Dict: """Benchmark all implementations for a specific (m, n, k) shape.""" if verbose: @@ -73,14 +73,22 @@ def benchmark_shape(m: int, # Block size configuration block_size = [128, 128] + # Pre-quantize A for all implementations + A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) + A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) + C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + A, block_size[1], column_major_scales=True) + # === DeepGEMM Implementation === def deepgemm_gemm(): # A quantization is inside the loop as it depends on activations # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( - A, block_size[1]) + # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( + # A, block_size[1]) # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm) @@ -89,7 +97,7 @@ def deepgemm_gemm(): # === vLLM Triton Implementation === def vllm_triton_gemm(): # A quantization is inside the loop as it depends on activations - A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) return w8a8_block_fp8_matmul(A_vllm, B_vllm, A_scale_vllm, @@ -100,8 +108,8 @@ def vllm_triton_gemm(): # === vLLM CUTLASS Implementation === def vllm_cutlass_gemm(): # A quantization is inside the loop as it depends on activations - A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - A, block_size[1], column_major_scales=True) + # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + # A, block_size[1], column_major_scales=True) return ops.cutlass_scaled_mm(A_vllm_cutlass, B_vllm.T, scale_a=A_scale_vllm_cutlass, @@ -287,6 +295,26 @@ def run_benchmarks(verbose: bool = False): (2048, 4096, 7168), (4096, 4096, 7168), ] + shapes = [ + # (64, 2112, 7168), + (64, 24576, 1536), + (64, 32768, 512), + (64, 7168, 16384), + (64, 4096, 7168), + (64, 7168, 2048), + # (128, 2112, 7168), + (128, 24576, 1536), + (128, 32768, 512), + (128, 7168, 16384), + (128, 4096, 7168), + (128, 7168, 2048), + # (4096, 2112, 7168), + (4096, 24576, 1536), + (4096, 32768, 512), + (4096, 7168, 16384), + (4096, 4096, 7168), + (4096, 7168, 2048), + ] all_results = [] for m, n, k in shapes: From 07c8762741659425896ee0453dce379428b3e058 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 5 Mar 2025 02:07:31 +0000 Subject: [PATCH 7/7] Update Signed-off-by: mgoin --- benchmarks/kernels/deepgemm/README.md | 8 +- .../benchmark_fp8_block_dense_gemm.py | 391 ++++++++++----- .../benchmark_fp8_block_dense_gemm_table.py | 465 ------------------ 3 files changed, 278 insertions(+), 586 deletions(-) delete mode 100644 benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py diff --git a/benchmarks/kernels/deepgemm/README.md b/benchmarks/kernels/deepgemm/README.md index e2120f53f3d6..917e814010f8 100644 --- a/benchmarks/kernels/deepgemm/README.md +++ b/benchmarks/kernels/deepgemm/README.md @@ -6,17 +6,19 @@ Currently this just includes 
dense GEMMs and only works on Hopper GPUs. ## Setup -You need to install vLLM in your usual fashion, then install DeepGEMM from source: +You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory: ``` git clone --recursive https://github.com/deepseek-ai/DeepGEMM -uv pip install -e DeepGEMM +cd DeepGEMM +python setup.py install +uv pip install -e . ``` ## Usage ``` -python benchmark_fp8_block_dense_gemm_table.py +python benchmark_fp8_block_dense_gemm.py INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda. ===== STARTING FP8 GEMM BENCHMARK ===== PyTorch version: 2.5.1+cu124 diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index e6abf91715e9..7892f126e7d6 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -2,12 +2,12 @@ # fmt: off # ruff: noqa: E501 import time -from typing import Dict, Tuple # Import DeepGEMM functions import deep_gemm import torch -from deep_gemm import calc_diff, cell_div, get_col_major_tma_aligned_tensor +import triton +from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor # Import vLLM functions from vllm import _custom_ops as ops @@ -18,7 +18,7 @@ # Copied from # https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9 def per_token_cast_to_fp8( - x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Convert tensor to FP8 format with per-token scaling.""" assert x.dim() == 2 and x.size(1) % 128 == 0 m, n = x.shape @@ -31,11 +31,11 @@ def per_token_cast_to_fp8( # Copied from # https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17 def per_block_cast_to_fp8( - x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Convert tensor to FP8 format with per-block scaling.""" assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((cell_div(m, 128) * 128, cell_div(n, 128) * 128), + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x @@ -49,10 +49,12 @@ def per_block_cast_to_fp8( def benchmark_shape(m: int, n: int, k: int, - warmup: int = 10, - repeat: int = 1000) -> Dict[str, Dict[str, float]]: + warmup: int = 100, + repeat: int = 10000, + verbose: bool = False) -> dict: """Benchmark all implementations for a specific (m, n, k) shape.""" - print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") + if verbose: + print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") # Create test tensors A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) @@ -70,17 +72,23 @@ def benchmark_shape(m: int, # Block size configuration block_size = [128, 128] - results = {} + # Pre-quantize A for all implementations + A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) + A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) + C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + A, block_size[1], column_major_scales=True) # === DeepGEMM Implementation === def deepgemm_gemm(): # A quantization is inside the loop as it 
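For reference, the DeepGEMM path exercised by this file, condensed into a standalone helper. This is a sketch under the benchmark's own assumptions (A is an (m, k) activation quantized per token group, B is an (n, k) weight pre-quantized per 128x128 block); the wrapper name is mine:

```python
import deep_gemm
import torch
from deep_gemm import get_col_major_tma_aligned_tensor


def deepgemm_block_fp8_matmul(A_fp8, A_scale, B_fp8, B_scale, m: int, n: int):
    # DeepGEMM expects activation scales in a column-major, TMA-aligned layout.
    A_scale = get_col_major_tma_aligned_tensor(A_scale)
    # The "nt" kernel computes C = A @ B.T with FP8 inputs and a BF16 output.
    C = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, A_scale), (B_fp8, B_scale), C)
    return C
```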
depends on activations # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( - A, block_size[1]) - A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_aligned), + # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( + # A, block_size[1]) + # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) + # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm) return C_deepgemm @@ -88,7 +96,7 @@ def deepgemm_gemm(): # === vLLM Triton Implementation === def vllm_triton_gemm(): # A quantization is inside the loop as it depends on activations - A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) + # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) return w8a8_block_fp8_matmul(A_vllm, B_vllm, A_scale_vllm, @@ -99,8 +107,8 @@ def vllm_triton_gemm(): # === vLLM CUTLASS Implementation === def vllm_cutlass_gemm(): # A quantization is inside the loop as it depends on activations - A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - A, block_size[1], column_major_scales=True) + # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( + # A, block_size[1], column_major_scales=True) return ops.cutlass_scaled_mm(A_vllm_cutlass, B_vllm.T, scale_a=A_scale_vllm_cutlass, @@ -108,7 +116,8 @@ def vllm_cutlass_gemm(): out_dtype=torch.bfloat16) # Run correctness check first - print("Running correctness check...") + if verbose: + print("Running correctness check...") C_deepgemm = deepgemm_gemm() C_vllm_triton = vllm_triton_gemm() C_vllm_cutlass = vllm_cutlass_gemm() @@ -117,13 +126,14 @@ def vllm_cutlass_gemm(): vllm_triton_diff = calc_diff(C_vllm_triton, C_ref) vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref) - print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") - print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") - print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") - print("vLLM Triton vs DeepGEMM difference: " - f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") - print("vLLM CUTLASS vs DeepGEMM difference: " - f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + if verbose: + print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") + print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") + print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") + print("vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") + print("vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") # Benchmark implementations implementations = { @@ -132,6 +142,15 @@ def vllm_cutlass_gemm(): "vLLM CUTLASS": vllm_cutlass_gemm } + benchmark_results = { + "shape": { + "m": m, + "n": n, + "k": k + }, + "implementations": {} + } + for name, func in implementations.items(): # Warmup for _ in range(warmup): @@ -148,17 +167,18 @@ def vllm_cutlass_gemm(): # Calculate timing and TFLOPS avg_time_ms = (end - start) / repeat * 1000 + avg_time_us = avg_time_ms * 1000 tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12 - mem_bw = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3) + gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3) - results[name] = { + 
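The TFLOPS and GB/s columns come straight from the two expressions above: 2·m·n·k floating-point operations per GEMM, and m·k + k·n FP8 input bytes plus 2·m·n BF16 output bytes. A quick standalone sanity check (helper name hypothetical):

```python
def gemm_rates(m: int, n: int, k: int, time_us: float) -> tuple[float, float]:
    """(TFLOPS, GB/s) for one FP8 GEMM with a BF16 output."""
    seconds = time_us * 1e-6
    flops = 2 * m * n * k                    # one multiply and one add per term
    bytes_moved = m * k + k * n + 2 * m * n  # 1-byte FP8 A and B, 2-byte BF16 C
    return flops / seconds / 1e12, bytes_moved / seconds / 1e9


# m=128, n=18432, k=7168 at 69.3 us -> roughly (488, 1988), matching the
# DeepGEMM row for that shape in the README tables above.
print(gemm_rates(128, 18432, 7168, 69.3))
```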
benchmark_results["implementations"][name] = { "time_ms": avg_time_ms, + "time_us": avg_time_us, "tflops": tflops, - "mem_bw": mem_bw, + "gb_s": gb_s, "diff": { "DeepGEMM": - deepgemm_diff if name == "DeepGEMM" else calc_diff( - func(), C_deepgemm), + 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), "Reference": deepgemm_diff if name == "DeepGEMM" else (vllm_triton_diff @@ -166,29 +186,73 @@ def vllm_cutlass_gemm(): } } - print( - f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {mem_bw:.2f} GB/s" - ) + if verbose: + print( + f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" + ) # Calculate speedups - baseline = results["DeepGEMM"]["time_ms"] - for name, data in results.items(): + baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] + for name, data in benchmark_results["implementations"].items(): if name != "DeepGEMM": - speedup = data["time_ms"] / baseline - print(f"DeepGEMM is {speedup:.2f}x " - f"{'faster' if speedup > 1 else 'slower'} than {name}") - - vllm_triton_time = results["vLLM Triton"]["time_ms"] - vllm_cutlass_time = results["vLLM CUTLASS"]["time_ms"] + speedup = baseline / data["time_ms"] + benchmark_results["implementations"][name][ + "speedup_vs_deepgemm"] = speedup + if verbose: + print(f"DeepGEMM is {1/speedup:.2f}x " + f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") + + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ + "time_ms"] + vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ + "time_ms"] cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time - print( - f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " - f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton") + benchmark_results["implementations"]["vLLM CUTLASS"][ + "speedup_vs_triton"] = cutlass_vs_triton + if verbose: + print( + f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " + f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton" + ) + + return benchmark_results + - return results +def format_table_row(values, widths): + """Format a row with specified column widths.""" + return "| " + " | ".join(f"{val:{w}}" + for val, w in zip(values, widths)) + " |" -def run_benchmarks(): +def print_table(headers, rows, title=None): + """Print a table with headers and rows.""" + if title: + print(f"\n{title}") + + # Calculate column widths based on headers and data + widths = [ + max(len(str(h)), max(len(str(row[i])) for row in rows)) + for i, h in enumerate(headers) + ] + + # Create separator line + separator = "+-" + "-+-".join("-" * w for w in widths) + "-+" + + # Print table + print(separator) + print(format_table_row(headers, widths)) + print(separator) + for row in rows: + print(format_table_row(row, widths)) + print(separator) + + +def format_speedup(value): + """Format speedup value with indicator if it's faster or slower.""" + return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}" + + +def run_benchmarks(verbose: bool = False): """Run benchmarks for a set of common shapes.""" print("===== STARTING FP8 GEMM BENCHMARK =====") @@ -197,6 +261,10 @@ def run_benchmarks(): print("CUDA not available! 
Tests require GPU.") return + # Print system information + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA version: {torch.version.cuda}") + print(f"Triton version: {triton.__version__}") print(f"Using device: {torch.cuda.get_device_name()}") # Enable TF32 for better performance @@ -212,98 +280,185 @@ def run_benchmarks(): (8, 4096, 7168), (8, 7168, 18432), (8, 18432, 7168), + (64, 4096, 7168), + (64, 7168, 18432), + (64, 18432, 7168), (64, 24576, 1536), (64, 32768, 512), (64, 7168, 16384), - (64, 4096, 7168), (128, 4096, 7168), (128, 7168, 18432), (128, 18432, 7168), (1024, 4096, 7168), (1024, 18432, 7168), (2048, 4096, 7168), + (4096, 4096, 7168), + ] + shapes = [ + # (64, 2112, 7168), + (64, 24576, 1536), + (64, 32768, 512), + (64, 7168, 16384), + (64, 4096, 7168), + (64, 7168, 2048), + # (128, 2112, 7168), + (128, 24576, 1536), + (128, 32768, 512), + (128, 7168, 16384), + (128, 4096, 7168), + (128, 7168, 2048), + # (4096, 2112, 7168), + (4096, 24576, 1536), + (4096, 32768, 512), + (4096, 7168, 16384), + (4096, 4096, 7168), + (4096, 7168, 2048), ] - # # Taken from - # # https://github.com/deepseek-ai/DeepGEMM?tab=readme-ov-file#normal-gemms-for-dense-models - # shapes = [ - # # (64, 2112, 7168), # Unsupported by CUTLASS - # (64, 24576, 1536), - # (64, 32768, 512), - # (64, 7168, 16384), - # (64, 4096, 7168), - # # (64, 7168, 2048), # Unsupported by DeepGEMM - - # # (128, 2112, 7168), # Unsupported by CUTLASS - # # (128, 24576, 1536), # Unsupported by DeepGEMM - # (128, 32768, 512), - # (128, 7168, 16384), - # (128, 4096, 7168), - # # (128, 7168, 2048), # Unsupported by DeepGEMM - - # # (4096, 2112, 7168), # Unsupported by CUTLASS - # (4096, 24576, 1536), - # (4096, 32768, 512), - # (4096, 7168, 16384), - # (4096, 4096, 7168), - # # (4096, 7168, 2048), # Unsupported by DeepGEMM - # ] - - all_results = {} + all_results = [] for m, n, k in shapes: - shape_key = f"m{m}_n{n}_k{k}" - all_results[shape_key] = benchmark_shape(m, n, k) - - print("\n===== BENCHMARK SUMMARY =====") - print("Matrix multiplication: C[m,n] = A[m,k] @ B[n,k].T") - print("\nAverage speedups:") - - # Calculate average speedups across all shapes - speedups = { - "DeepGEMM vs vLLM Triton": [], - "DeepGEMM vs vLLM CUTLASS": [], - "vLLM CUTLASS vs vLLM Triton": [] - } + result = benchmark_shape(m, n, k, verbose=verbose) + all_results.append(result) + + # Print results in a nicely formatted table + print("\n===== PERFORMANCE COMPARISON =====") + + # Print DeepGEMM table + deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"] + deepgemm_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["DeepGEMM"] + deepgemm_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" + ]) + + print_table(deepgemm_headers, + deepgemm_rows, + title="DeepGEMM Implementation:") + + # Print vLLM Triton table + triton_headers = [ + "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" + ] + triton_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM Triton"] + speedup = impl_data.get("speedup_vs_deepgemm", 1.0) + triton_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", + format_speedup(speedup) + ]) + + print_table(triton_headers, + triton_rows, + title="vLLM Triton Implementation:") + + # Print vLLM CUTLASS table + 
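The CUTLASS rows tabulated below come from vLLM's ops.cutlass_scaled_mm, fed with the column-major activation scales that the earlier hunks pre-compute via column_major_scales=True. Condensed into one standalone helper as a sketch (the wrapper name is mine; the calls match the benchmark):

```python
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)


def cutlass_block_fp8_matmul(A, B_fp8, B_scale, group_size: int = 128):
    # CUTLASS wants the per-token-group activation scales column-major.
    A_fp8, A_scale = per_token_group_quant_fp8(
        A, group_size, column_major_scales=True)
    # B is pre-quantized per 128x128 block with shape (n, k), so pass B.T.
    return ops.cutlass_scaled_mm(A_fp8,
                                 B_fp8.T,
                                 scale_a=A_scale,
                                 scale_b=B_scale.T,
                                 out_dtype=torch.bfloat16)
```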
cutlass_headers = [ + "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", + "vs Triton" + ] + cutlass_rows = [] + for result in all_results: + shape = result["shape"] + impl_data = result["implementations"]["vLLM CUTLASS"] + vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) + vs_triton = impl_data.get("speedup_vs_triton", 1.0) + cutlass_rows.append([ + shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton) + ]) + + print_table(cutlass_headers, + cutlass_rows, + title="vLLM CUTLASS Implementation:") + + # Calculate and print averages + print("\n===== AVERAGE PERFORMANCE =====") - for shape_key, results in all_results.items(): - deepgemm_time = results["DeepGEMM"]["time_ms"] - vllm_triton_time = results["vLLM Triton"]["time_ms"] - vllm_cutlass_time = results["vLLM CUTLASS"]["time_ms"] + implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] + avg_metrics = { + impl: { + "tflops": 0, + "gb_s": 0, + "time_ms": 0 + } + for impl in implementations + } - speedups["DeepGEMM vs vLLM Triton"].append(vllm_triton_time / - deepgemm_time) - speedups["DeepGEMM vs vLLM CUTLASS"].append(vllm_cutlass_time / - deepgemm_time) - speedups["vLLM CUTLASS vs vLLM Triton"].append(vllm_triton_time / - vllm_cutlass_time) + for result in all_results: + for impl in implementations: + impl_data = result["implementations"][impl] + avg_metrics[impl]["tflops"] += impl_data["tflops"] + avg_metrics[impl]["gb_s"] += impl_data["gb_s"] + avg_metrics[impl]["time_ms"] += impl_data["time_ms"] - for comparison, values in speedups.items(): - avg_speedup = sum(values) / len(values) - print(f"{comparison}: {avg_speedup:.2f}x " - f"{'faster' if avg_speedup > 1 else 'slower'}") + num_shapes = len(all_results) + avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"] + avg_rows = [] - print("\nAverage TFLOPS:") - implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] for impl in implementations: - avg_tflops = sum( - results[impl]["tflops"] - for results in all_results.values()) / len(all_results) - print(f"{impl}: {avg_tflops:.2f} TFLOPS") + avg_tflops = avg_metrics[impl]["tflops"] / num_shapes + avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes + avg_time = avg_metrics[impl]["time_ms"] / num_shapes + avg_rows.append([ + impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" + ]) + + print_table(avg_headers, avg_rows) + + # Calculate average speedups + avg_speedups = { + "DeepGEMM vs vLLM Triton": 0, + "DeepGEMM vs vLLM CUTLASS": 0, + "vLLM CUTLASS vs vLLM Triton": 0 + } - print("\nAverage GB/s:") - implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] + for result in all_results: + deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] + vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ + "time_ms"] + + avg_speedups[ + "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups[ + "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups[ + "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time + + print("\n===== AVERAGE SPEEDUPS =====") + speedup_headers = ["Comparison", "Speedup"] + speedup_rows = [] + for comparison, total in avg_speedups.items(): + avg_speedup = total / num_shapes + status = "faster" if avg_speedup > 1 else "slower" + speedup_rows.append([comparison, 
f"{avg_speedup:.2f}x {status}"]) + + print_table(speedup_headers, speedup_rows) + + # Average accuracy comparison + print("\n===== ACCURACY COMPARISON =====") + avg_diff = {impl: 0 for impl in implementations} + + for result in all_results: + for impl in implementations: + avg_diff[impl] += result["implementations"][impl]["diff"][ + "Reference"] + + diff_headers = ["Implementation", "Avg Diff vs Reference"] + diff_rows = [] for impl in implementations: - avg_mem_bw = sum( - results[impl]["mem_bw"] - for results in all_results.values()) / len(all_results) - print(f"{impl}: {avg_mem_bw:.2f} GB/s") + diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"]) - print("\nAverage accuracy difference vs reference:") - for impl in implementations: - avg_diff = sum(results[impl]["diff"]["Reference"] - for results in all_results.values()) / len(all_results) - print(f"{impl}: {avg_diff:.6f}") + print_table(diff_headers, diff_rows) if __name__ == "__main__": - run_benchmarks() + run_benchmarks(verbose=False) diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py deleted file mode 100644 index 4af961a62adf..000000000000 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm_table.py +++ /dev/null @@ -1,465 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# fmt: off -# ruff: noqa: E501 -import time -from typing import Dict, Tuple - -# Import DeepGEMM functions -import deep_gemm -import torch -import triton -from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor - -# Import vLLM functions -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) - - -# Copied from -# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9 -def per_token_cast_to_fp8( - x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert tensor to FP8 format with per-token scaling.""" - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to( - torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) - - -# Copied from -# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17 -def per_block_cast_to_fp8( - x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert tensor to FP8 format with per-block scaling.""" - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( - x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - - -def benchmark_shape(m: int, - n: int, - k: int, - warmup: int = 100, - repeat: int = 10000, - verbose: bool = False) -> Dict: - """Benchmark all implementations for a specific (m, n, k) shape.""" - if verbose: - print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") - - # Create test tensors - A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - B = torch.randn((n, k), device='cuda', 
dtype=torch.bfloat16) - - # Reference result in BF16 - torch.cuda.synchronize() - C_ref = A @ B.t() - - # Pre-quantize B for all implementations - # (weights can be pre-quantized offline) - B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B) - B_vllm, B_scale_vllm = per_block_cast_to_fp8(B) - - # Block size configuration - block_size = [128, 128] - - # Pre-quantize A for all implementations - A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) - A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - A, block_size[1], column_major_scales=True) - - # === DeepGEMM Implementation === - def deepgemm_gemm(): - # A quantization is inside the loop as it depends on activations - # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( - # A, block_size[1]) - # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), - (B_deepgemm, B_scale_deepgemm), - C_deepgemm) - return C_deepgemm - - # === vLLM Triton Implementation === - def vllm_triton_gemm(): - # A quantization is inside the loop as it depends on activations - # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) - return w8a8_block_fp8_matmul(A_vllm, - B_vllm, - A_scale_vllm, - B_scale_vllm, - block_size, - output_dtype=torch.bfloat16) - - # === vLLM CUTLASS Implementation === - def vllm_cutlass_gemm(): - # A quantization is inside the loop as it depends on activations - # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - # A, block_size[1], column_major_scales=True) - return ops.cutlass_scaled_mm(A_vllm_cutlass, - B_vllm.T, - scale_a=A_scale_vllm_cutlass, - scale_b=B_scale_vllm.T, - out_dtype=torch.bfloat16) - - # Run correctness check first - if verbose: - print("Running correctness check...") - C_deepgemm = deepgemm_gemm() - C_vllm_triton = vllm_triton_gemm() - C_vllm_cutlass = vllm_cutlass_gemm() - - deepgemm_diff = calc_diff(C_deepgemm, C_ref) - vllm_triton_diff = calc_diff(C_vllm_triton, C_ref) - vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref) - - if verbose: - print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") - print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") - print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") - print("vLLM Triton vs DeepGEMM difference: " - f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") - print("vLLM CUTLASS vs DeepGEMM difference: " - f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") - - # Benchmark implementations - implementations = { - "DeepGEMM": deepgemm_gemm, - "vLLM Triton": vllm_triton_gemm, - "vLLM CUTLASS": vllm_cutlass_gemm - } - - benchmark_results = { - "shape": { - "m": m, - "n": n, - "k": k - }, - "implementations": {} - } - - for name, func in implementations.items(): - # Warmup - for _ in range(warmup): - func() - torch.cuda.synchronize() - - # Timing loop - torch.cuda.synchronize() - start = time.time() - for _ in range(repeat): - func() - torch.cuda.synchronize() - end = time.time() - - # Calculate timing and TFLOPS - avg_time_ms = (end - start) / repeat * 1000 - avg_time_us = avg_time_ms * 1000 - tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12 - gb_s = (m * k + 
k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3) - - benchmark_results["implementations"][name] = { - "time_ms": avg_time_ms, - "time_us": avg_time_us, - "tflops": tflops, - "gb_s": gb_s, - "diff": { - "DeepGEMM": - 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), - "Reference": - deepgemm_diff if name == "DeepGEMM" else - (vllm_triton_diff - if name == "vLLM Triton" else vllm_cutlass_diff) - } - } - - if verbose: - print( - f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" - ) - - # Calculate speedups - baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] - for name, data in benchmark_results["implementations"].items(): - if name != "DeepGEMM": - speedup = baseline / data["time_ms"] - benchmark_results["implementations"][name][ - "speedup_vs_deepgemm"] = speedup - if verbose: - print(f"DeepGEMM is {1/speedup:.2f}x " - f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") - - vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ - "time_ms"] - vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ - "time_ms"] - cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time - benchmark_results["implementations"]["vLLM CUTLASS"][ - "speedup_vs_triton"] = cutlass_vs_triton - if verbose: - print( - f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " - f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton" - ) - - return benchmark_results - - -def format_table_row(values, widths): - """Format a row with specified column widths.""" - return "| " + " | ".join(f"{val:{w}}" - for val, w in zip(values, widths)) + " |" - - -def print_table(headers, rows, title=None): - """Print a table with headers and rows.""" - if title: - print(f"\n{title}") - - # Calculate column widths based on headers and data - widths = [ - max(len(str(h)), max(len(str(row[i])) for row in rows)) - for i, h in enumerate(headers) - ] - - # Create separator line - separator = "+-" + "-+-".join("-" * w for w in widths) + "-+" - - # Print table - print(separator) - print(format_table_row(headers, widths)) - print(separator) - for row in rows: - print(format_table_row(row, widths)) - print(separator) - - -def format_speedup(value): - """Format speedup value with indicator if it's faster or slower.""" - return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}" - - -def run_benchmarks(verbose: bool = False): - """Run benchmarks for a set of common shapes.""" - print("===== STARTING FP8 GEMM BENCHMARK =====") - - # Make sure we're using the GPU - if not torch.cuda.is_available(): - print("CUDA not available! 
Tests require GPU.") - return - - # Print system information - print(f"PyTorch version: {torch.__version__}") - print(f"CUDA version: {torch.version.cuda}") - print(f"Triton version: {triton.__version__}") - print(f"Using device: {torch.cuda.get_device_name()}") - - # Enable TF32 for better performance - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - # Set seeds for reproducibility - torch.manual_seed(42) - torch.cuda.manual_seed(42) - - # Define benchmark shapes (m, n, k) - shapes = [ - (8, 4096, 7168), - (8, 7168, 18432), - (8, 18432, 7168), - (64, 4096, 7168), - (64, 7168, 18432), - (64, 18432, 7168), - (64, 24576, 1536), - (64, 32768, 512), - (64, 7168, 16384), - (128, 4096, 7168), - (128, 7168, 18432), - (128, 18432, 7168), - (1024, 4096, 7168), - (1024, 18432, 7168), - (2048, 4096, 7168), - (4096, 4096, 7168), - ] - shapes = [ - # (64, 2112, 7168), - (64, 24576, 1536), - (64, 32768, 512), - (64, 7168, 16384), - (64, 4096, 7168), - (64, 7168, 2048), - # (128, 2112, 7168), - (128, 24576, 1536), - (128, 32768, 512), - (128, 7168, 16384), - (128, 4096, 7168), - (128, 7168, 2048), - # (4096, 2112, 7168), - (4096, 24576, 1536), - (4096, 32768, 512), - (4096, 7168, 16384), - (4096, 4096, 7168), - (4096, 7168, 2048), - ] - - all_results = [] - for m, n, k in shapes: - result = benchmark_shape(m, n, k, verbose=verbose) - all_results.append(result) - - # Print results in a nicely formatted table - print("\n===== PERFORMANCE COMPARISON =====") - - # Print DeepGEMM table - deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"] - deepgemm_rows = [] - for result in all_results: - shape = result["shape"] - impl_data = result["implementations"]["DeepGEMM"] - deepgemm_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" - ]) - - print_table(deepgemm_headers, - deepgemm_rows, - title="DeepGEMM Implementation:") - - # Print vLLM Triton table - triton_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" - ] - triton_rows = [] - for result in all_results: - shape = result["shape"] - impl_data = result["implementations"]["vLLM Triton"] - speedup = impl_data.get("speedup_vs_deepgemm", 1.0) - triton_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(speedup) - ]) - - print_table(triton_headers, - triton_rows, - title="vLLM Triton Implementation:") - - # Print vLLM CUTLASS table - cutlass_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", - "vs Triton" - ] - cutlass_rows = [] - for result in all_results: - shape = result["shape"] - impl_data = result["implementations"]["vLLM CUTLASS"] - vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) - vs_triton = impl_data.get("speedup_vs_triton", 1.0) - cutlass_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(vs_deepgemm), - format_speedup(vs_triton) - ]) - - print_table(cutlass_headers, - cutlass_rows, - title="vLLM CUTLASS Implementation:") - - # Calculate and print averages - print("\n===== AVERAGE PERFORMANCE =====") - - implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] - avg_metrics = { - impl: { - "tflops": 0, - "gb_s": 0, - "time_ms": 0 - } - for impl in implementations - } - - for result in all_results: - for impl in implementations: - impl_data 
= result["implementations"][impl] - avg_metrics[impl]["tflops"] += impl_data["tflops"] - avg_metrics[impl]["gb_s"] += impl_data["gb_s"] - avg_metrics[impl]["time_ms"] += impl_data["time_ms"] - - num_shapes = len(all_results) - avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"] - avg_rows = [] - - for impl in implementations: - avg_tflops = avg_metrics[impl]["tflops"] / num_shapes - avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes - avg_time = avg_metrics[impl]["time_ms"] / num_shapes - avg_rows.append([ - impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" - ]) - - print_table(avg_headers, avg_rows) - - # Calculate average speedups - avg_speedups = { - "DeepGEMM vs vLLM Triton": 0, - "DeepGEMM vs vLLM CUTLASS": 0, - "vLLM CUTLASS vs vLLM Triton": 0 - } - - for result in all_results: - deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] - vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] - vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ - "time_ms"] - - avg_speedups[ - "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time - avg_speedups[ - "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time - avg_speedups[ - "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time - - print("\n===== AVERAGE SPEEDUPS =====") - speedup_headers = ["Comparison", "Speedup"] - speedup_rows = [] - for comparison, total in avg_speedups.items(): - avg_speedup = total / num_shapes - status = "faster" if avg_speedup > 1 else "slower" - speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"]) - - print_table(speedup_headers, speedup_rows) - - # Average accuracy comparison - print("\n===== ACCURACY COMPARISON =====") - avg_diff = {impl: 0 for impl in implementations} - - for result in all_results: - for impl in implementations: - avg_diff[impl] += result["implementations"][impl]["diff"][ - "Reference"] - - diff_headers = ["Implementation", "Avg Diff vs Reference"] - diff_rows = [] - for impl in implementations: - diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"]) - - print_table(diff_headers, diff_rows) - - -if __name__ == "__main__": - run_benchmarks(verbose=False)
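Taken together, the series leaves a single script, benchmark_fp8_block_dense_gemm.py. When reading its tables, note that a "vs DeepGEMM" entry above 1.0x means that implementation outran DeepGEMM on that shape, since the ratio is DeepGEMM time divided by the other implementation's time. Beyond running the script directly, individual shapes can be probed by importing benchmark_shape; a usage sketch, assuming the benchmarks/kernels/deepgemm directory is on the path and a Hopper GPU is available:

```python
# Hypothetical usage: benchmark one extra shape and print the
# per-implementation numbers that the summary tables aggregate.
from benchmark_fp8_block_dense_gemm import benchmark_shape

result = benchmark_shape(m=512, n=7168, k=7168, verbose=True)
for name, data in result["implementations"].items():
    print(f"{name}: {data['time_us']:.1f} us, "
          f"{data['tflops']:.1f} TFLOPS, {data['gb_s']:.1f} GB/s")
```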