15 changes: 15 additions & 0 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -7,6 +7,9 @@
import torch._inductor.config as inductor_config
import triton

from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
@@ -19,6 +22,10 @@

from .tutorial import matmul as tutorial_matmul

IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"

torch._dynamo.config.recompile_limit = 10000

logger = logging.getLogger(__name__)
try:
from .persistent import (
@@ -169,6 +176,14 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:

return lambda: compiled(a, b)

if IS_B200:

@register_benchmark(enabled=True)
def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
return lambda: blackwell_persistent_tma(
a, b.T, scale_a, scale_b.T, self._get_dtype()
)

@register_benchmark()
def triton_fp8_gemm(self, a, b, scale_a, scale_b):
return lambda: tutorial_matmul(a, b)
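For orientation, a minimal sketch (not part of this diff) of calling the new kernel entry point directly. The sizes, scale values, and bfloat16 output dtype are illustrative assumptions; b is constructed in the (N, K) layout that blackwell_persistent_tma asserts on.

import torch

from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma

# Assumed sizes; on B200, K and N must be multiples of 16.
M, N, K = 1024, 1024, 1024
a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scales (0-d tensors)
scale_b = torch.tensor(1.0, device="cuda")

c = blackwell_persistent_tma(a, b, scale_a, scale_b, torch.bfloat16)
assert c.shape == (M, N) and c.dtype == torch.bfloat16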
193 changes: 192 additions & 1 deletion tritonbench/operators/fp8_gemm/persistent.py
@@ -1,11 +1,15 @@
from functools import lru_cache
from typing import Optional

import torch
import triton
import triton.language as tl
import triton.tools.experimental_descriptor

from tritonbench.utils.env_utils import is_cuda
from tritonbench.utils.triton_utils import has_experimental_descriptor

if has_experimental_descriptor():
import triton.tools.experimental_descriptor

cublas = None
if is_cuda():
@@ -289,6 +293,36 @@ def matmul_configs():
}


def matmul_configs_blackwell():
# Autotuner does not work with TMA. Use manual config.
return {
torch.float8_e4m3fn: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_stages": 4,
"num_warps": 4, # Note: num_warps >= 4 required for TMA
},
torch.float16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 2,
"num_warps": 2,
},
torch.bfloat16: {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_stages": 2,
"num_warps": 2,
},
}


def allocate_matmul_tma(a, b):
configs = matmul_configs()

@@ -364,3 +398,160 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
num_warps=configs[dtype]["num_warps"], #
)
return c


# Blackwell Persistent + TMA
# Restrictions:
# - (K, N) must be a multiple of 16 on B200 for all benchmarks
# - num_warps >= 4
# - TMA instructions expect at least a 128-thread group
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps


def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
configs = matmul_configs_blackwell()

# Check constraints.
assert (
a.shape[1] == b.shape[1]
), "Incompatible dimensions" # a.shape = (M, K), b.shape = (N, K)
assert a.dtype == b.dtype, "Incompatible dtypes"

M, K = a.shape
N, K = b.shape
shape_dtype = a.dtype # low-precision dtype, e.g. fp8

NUM_SMS = torch.cuda.get_device_properties(
torch.cuda.current_device()
).multi_processor_count

c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)

def alloc_fn(size: int, align: int, stream: Optional[int]):
return torch.empty(size, dtype=torch.int8, device=a.device)

if hasattr(triton, "set_allocator"):
triton.set_allocator(alloc_fn)
else:
return c

if acc_dtype == torch.float16:
acc_dtype_tl = tl.float16
elif acc_dtype == torch.bfloat16:
acc_dtype_tl = tl.bfloat16
else:
raise NotImplementedError(
"Output types other than torch.float16 and torch.bfloat16 unsupported for FP8 Blackwell persistent + TMA kernels"
)

grid = lambda META: (
min(
NUM_SMS,
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
),
)
blackwell_persistent_tma_kernel[grid](
a,
b,
c, #
M,
N,
K, #
scale_a.item(), #
scale_b.item(), #
BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"], #
BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"], #
BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"], #
GROUP_SIZE_M=configs[shape_dtype]["GROUP_SIZE_M"], #
ACC_TYPE=acc_dtype_tl,
NUM_SMS=NUM_SMS, #
num_stages=configs[shape_dtype]["num_stages"], #
num_warps=configs[shape_dtype]["num_warps"], #
)
return c
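
The launch grid above caps the number of programs at NUM_SMS, so each resident program loops over several output tiles. A rough worked example, with an assumed SM count and problem size:

import triton

# Assumed values, for illustration only.
M = N = 4096
BLOCK_SIZE_M = BLOCK_SIZE_N = 128
NUM_SMS = 132

num_tiles = triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N)  # 32 * 32 = 1024
grid_size = min(NUM_SMS, num_tiles)  # 132 persistent programs
tiles_per_program = triton.cdiv(num_tiles, grid_size)  # each program walks ~8 tiles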


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n
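
To see what the grouped mapping in _compute_pid does, here is a host-side mirror run on a toy grid (sizes assumed). Consecutive tile ids stay within a band of GROUP_SIZE_M row blocks, so concurrent programs touch only a few distinct A row blocks at a time, which helps L2 reuse.

def compute_pid(tile_id, num_pid_in_group, num_pid_m, group_size_m_max):
    # Same arithmetic as the @triton.jit helper above (unused NUM_SMS argument dropped).
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * group_size_m_max
    group_size_m = min(num_pid_m - first_pid_m, group_size_m_max)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n

num_pid_m, num_pid_n, GROUP_SIZE_M = 4, 4, 2
num_pid_in_group = GROUP_SIZE_M * num_pid_n  # 8 tiles per group
print([compute_pid(t, num_pid_in_group, num_pid_m, GROUP_SIZE_M) for t in range(8)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]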


@triton.jit(launch_metadata=_matmul_launch_metadata)
def blackwell_persistent_tma_kernel(
a,
b,
acc,
M, #
N, #
K, #
scale_a, #
scale_b, #
BLOCK_SIZE_M: tl.constexpr, #
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACC_TYPE: tl.constexpr,
NUM_SMS: tl.constexpr,
): #
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n

a_desc = tl.make_tensor_descriptor(
a,
shape=[M, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl.make_tensor_descriptor(
b,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
acc_desc = tl.make_tensor_descriptor(
acc,
shape=[M, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)

tile_id_c = start_pid - NUM_SMS
num_pid_in_group = GROUP_SIZE_M * num_pid_n

for tile_id in tl.range(
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
):
pid_m, pid_n = _compute_pid(
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
offs_am = pid_m * BLOCK_SIZE_M
offs_bn = pid_n * BLOCK_SIZE_N

accumulator = tl.zeros(
(BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
) # accumulate in high precision (fp32) for high accuracy
for ki in range(k_tiles):
offs_k = ki * BLOCK_SIZE_K
a_block = a_desc.load([offs_am, offs_k])
b_block = b_desc.load([offs_bn, offs_k])
accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)

accumulator *= scale_a * scale_b # currently only supports per-tensor scaling

tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
offs_cm = pid_m * BLOCK_SIZE_M
offs_cn = pid_n * BLOCK_SIZE_N

c = accumulator.to(ACC_TYPE)
acc_desc.store([offs_cm, offs_cn], c)
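
One way to sanity-check the kernel numerically (again outside the diff, with assumed shapes and loose fp8 tolerances) is to compare against a plain float32 matmul of the dequantized operands with the per-tensor scales applied:

import torch

from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma

M = N = K = 512  # assumed; multiples of 16, as the restrictions above require
a = torch.randn(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
scale_a = torch.tensor(0.5, device="cuda")
scale_b = torch.tensor(0.25, device="cuda")

out = blackwell_persistent_tma(a, b, scale_a, scale_b, torch.bfloat16)
# Per-tensor scaling: dequantize both operands to fp32 and apply both scales.
ref = (a.float() @ b.float().T) * (scale_a * scale_b)
torch.testing.assert_close(out.float(), ref, atol=1e-1, rtol=1e-1)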