diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index 3a0897e97..4c1eb3b4a 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -7,6 +7,9 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
+from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
+
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -19,6 +22,10 @@
 from .tutorial import matmul as tutorial_matmul
 
+IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
+
+torch._dynamo.config.recompile_limit = 10000
+
 logger = logging.getLogger(__name__)
 
 try:
     from .persistent import (
@@ -169,6 +176,14 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
 
         return lambda: compiled(a, b)
 
+    if IS_B200:
+
+        @register_benchmark(enabled=True)
+        def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
+            return lambda: blackwell_persistent_tma(
+                a, b.T, scale_a, scale_b.T, self._get_dtype()
+            )
+
     @register_benchmark()
     def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
diff --git a/tritonbench/operators/fp8_gemm/persistent.py b/tritonbench/operators/fp8_gemm/persistent.py
index e9789f339..bf3f2b029 100644
--- a/tritonbench/operators/fp8_gemm/persistent.py
+++ b/tritonbench/operators/fp8_gemm/persistent.py
@@ -1,11 +1,15 @@
 from functools import lru_cache
+from typing import Optional
 
 import torch
 import triton
 import triton.language as tl
-import triton.tools.experimental_descriptor
 
 from tritonbench.utils.env_utils import is_cuda
+from tritonbench.utils.triton_utils import has_experimental_descriptor
+
+if has_experimental_descriptor():
+    import triton.tools.experimental_descriptor
 
 cublas = None
 if is_cuda():
@@ -289,6 +293,36 @@
     }
 
 
+def matmul_configs_blackwell():
+    # Autotuner does not work with TMA. Use manual config.
+    return {
+        torch.float8_e4m3fn: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 4,
+            "num_warps": 4,  # Note: num_warps >= 4 required for TMA
+        },
+        torch.float16: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 2,
+            "num_warps": 2,
+        },
+        torch.bfloat16: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 2,
+            "num_warps": 2,
+        },
+    }
+
+
 def allocate_matmul_tma(a, b):
     configs = matmul_configs()
 
@@ -364,3 +398,160 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
         num_warps=configs[dtype]["num_warps"],  #
     )
     return c
+
+
+# Blackwell Persistent + TMA
+# Restrictions:
+# - (K, N) must be a multiple of 16 on B200 for all benchmarks
+# - num_warps >= 4
+#   - TMA instructions expect at least a 128-thread group
+#   - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
+
+
+def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
+    configs = matmul_configs_blackwell()
+
+    # Check constraints.
+    assert (
+        a.shape[1] == b.shape[1]
+    ), "Incompatible dimensions"  # a.shape = (M, K), b.shape = (N, K)
+    assert a.dtype == b.dtype, "Incompatible dtypes"
+
+    M, K = a.shape
+    N, K = b.shape
+    shape_dtype = a.dtype  # low-precision dtype, e.g. fp8
+
+    NUM_SMS = torch.cuda.get_device_properties(
+        torch.cuda.current_device()
+    ).multi_processor_count
+
+    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        return torch.empty(size, dtype=torch.int8, device=a.device)
+
+    if hasattr(triton, "set_allocator"):
+        triton.set_allocator(alloc_fn)
+    else:
+        return c
+
+    if acc_dtype == torch.float16:
+        acc_dtype_tl = tl.float16
+    elif acc_dtype == torch.bfloat16:
+        acc_dtype_tl = tl.bfloat16
+    else:
+        raise NotImplementedError(
+            "Output types other than torch.float16 and torch.bfloat16 unsupported for FP8 Blackwell persistent + TMA kernels"
+        )
+
+    grid = lambda META: (
+        min(
+            NUM_SMS,
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        ),
+    )
+    blackwell_persistent_tma_kernel[grid](
+        a,
+        b,
+        c,  #
+        M,
+        N,
+        K,  #
+        scale_a.item(),  #
+        scale_b.item(),  #
+        BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"],  #
+        BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"],  #
+        BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"],  #
+        GROUP_SIZE_M=configs[shape_dtype]["GROUP_SIZE_M"],  #
+        ACC_TYPE=acc_dtype_tl,
+        NUM_SMS=NUM_SMS,  #
+        num_stages=configs[shape_dtype]["num_stages"],  #
+        num_warps=configs[shape_dtype]["num_warps"],  #
+    )
+    return c
+
+
+@triton.jit
+def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
+    group_id = tile_id // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (tile_id % group_size_m)
+    pid_n = (tile_id % num_pid_in_group) // group_size_m
+    return pid_m, pid_n
+
+
+@triton.jit(launch_metadata=_matmul_launch_metadata)
+def blackwell_persistent_tma_kernel(
+    a,
+    b,
+    acc,
+    M,  #
+    N,  #
+    K,  #
+    scale_a,  #
+    scale_b,  #
+    BLOCK_SIZE_M: tl.constexpr,  #
+    BLOCK_SIZE_N: tl.constexpr,  #
+    BLOCK_SIZE_K: tl.constexpr,  #
+    GROUP_SIZE_M: tl.constexpr,  #
+    ACC_TYPE: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+):  #
+    start_pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+    num_tiles = num_pid_m * num_pid_n
+
+    a_desc = tl.make_tensor_descriptor(
+        a,
+        shape=[M, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+    )
+    b_desc = tl.make_tensor_descriptor(
+        b,
+        shape=[N, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
+    )
+    acc_desc = tl.make_tensor_descriptor(
+        acc,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+    )
+
+    tile_id_c = start_pid - NUM_SMS
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+
+    for tile_id in tl.range(
+        start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
+    ):
+        pid_m, pid_n = _compute_pid(
+            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+        )
+        offs_am = pid_m * BLOCK_SIZE_M
+        offs_bn = pid_n * BLOCK_SIZE_N
+
+        accumulator = tl.zeros(
+            (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
+        )  # accumulate in high precision (fp32) for high accuracy
+        for ki in range(k_tiles):
+            offs_k = ki * BLOCK_SIZE_K
+            a_block = a_desc.load([offs_am, offs_k])
+            b_block = b_desc.load([offs_bn, offs_k])
+            accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
+
+        accumulator *= scale_a * scale_b  # currently only supports per-tensor scaling
+
+        tile_id_c += NUM_SMS
+        pid_m, pid_n = _compute_pid(
+            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+        )
+        offs_cm = pid_m * BLOCK_SIZE_M
+        offs_cn = pid_n * BLOCK_SIZE_N
+
+        c = accumulator.to(ACC_TYPE)
+        acc_desc.store([offs_cm, offs_cn], c)
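
A minimal standalone sketch (not part of the patch) for exercising the new entry point outside the benchmark harness. It assumes a B200 GPU and a Triton build with tensor-descriptor support; the shapes, scale values, and output dtype below are illustrative choices, not values taken from the patch.

    import torch

    from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma

    M, N, K = 1024, 1024, 1024  # K and N must be multiples of 16 on B200

    # A is (M, K) and B is (N, K), both row-major fp8 e4m3 with per-tensor scales.
    # The benchmark registration above reaches this layout by passing b.T; here B
    # is created in (N, K) form directly.
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    scale_a = torch.tensor(1.0, device="cuda")
    scale_b = torch.tensor(1.0, device="cuda")

    # acc_dtype selects the output precision; only fp16 and bf16 are supported.
    c = blackwell_persistent_tma(a, b, scale_a, scale_b, torch.bfloat16)
    print(c.shape, c.dtype)  # torch.Size([1024, 1024]) torch.bfloat16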