
Commit 161229f

jananisriram authored and facebook-github-bot committed
Prototype persistent + TMA kernel with warp specialization (#385)
Summary: Taking inspiration from D77053488, Triton's [persistent matmul](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#) tutorial, and Triton's [block scaled matmul](https://triton-lang.org/main/getting-started/tutorials/10-block-scaled-matmul.html) tutorial, write and benchmark a persistent + TMA Triton kernel for FP8 workloads on Blackwell that enables warp specialization and flattening.

Recall the following for FP8 workloads:
- Input dtype: `torch.float8_e4m3fn` (`tl.float8e4nv`)
- Output dtype: `torch.float16` (per-tensor scaling, `tl.float16`) or `torch.bfloat16` (per-row scaling, `tl.bfloat16`)
- Accumulation dtype: `torch.float32` (`tl.float32`)

Note the following limitations:
- `(K, N)` (the second matrix in the GEMM) must be a multiple of 16.
- `num_warps >= 4`: TMA instructions expect at least a 128-thread (4-warp) group. 1 warp = 32 threads, so we need 4 warps per thread block to ensure correct functionality.

Note that the current kernel is autotuned on just one config; this will be changed in a future diff.

Reviewed By: njriasan

Differential Revision: D81470285
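For quick reference, the dtype contract listed above can be written out as a small table in code. This is illustrative only; the `FP8_GEMM_DTYPES` name is ours, not part of the commit.

```python
import torch
import triton.language as tl

# Illustrative mapping of the FP8 GEMM dtype contract described in the summary.
FP8_GEMM_DTYPES = {
    "inputs": (torch.float8_e4m3fn, tl.float8e4nv),            # both GEMM operands
    "output_per_tensor_scaling": (torch.float16, tl.float16),
    "output_per_row_scaling": (torch.bfloat16, tl.bfloat16),
    "accumulator": (torch.float32, tl.float32),
}
```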
1 parent 069c05a commit 161229f

File tree (2 files changed: +207 / -1 lines)

- tritonbench/operators/fp8_gemm/fp8_gemm.py
- tritonbench/operators/fp8_gemm/persistent.py


tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 15 additions & 0 deletions
@@ -7,6 +7,9 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
+from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
+
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -19,6 +22,10 @@
 
 from .tutorial import matmul as tutorial_matmul
 
+IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
+
+torch._dynamo.config.recompile_limit = 10000
+
 logger = logging.getLogger(__name__)
 try:
     from .persistent import (
@@ -169,6 +176,14 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
 
         return lambda: compiled(a, b)
 
+    if IS_B200:
+
+        @register_benchmark(enabled=True)
+        def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
+            return lambda: blackwell_persistent_tma(
+                a, b.T, scale_a, scale_b.T, self._get_dtype()
+            )
+
     @register_benchmark()
     def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)

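The operator stores `b` as `(K, N)`, so the new benchmark passes `b.T`: `blackwell_persistent_tma` asserts `a.shape[1] == b.shape[1]`, i.e. it expects operands laid out as `(M, K)` and `(N, K)`. Below is a rough sketch of calling it directly, outside the benchmark harness; the sizes and scale values are assumptions, and it needs a Blackwell-class GPU plus a Triton build with tensor-descriptor support.

```python
import torch

from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma

# Assumed sizes; N and K are multiples of 16 to satisfy the TMA restriction.
M, K, N = 1024, 512, 2048

a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)  # (M, K)
b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)  # (K, N), as the operator stores it

# Per-tensor scales as scalar tensors; the kernel reads them with .item(),
# so transposing scale_b (as the benchmark wrapper does) is a no-op here.
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

# Mirror blackwell_persistent_tma_fp8_gemm: pass b.T so the kernel sees (N, K).
c = blackwell_persistent_tma(a, b.T, scale_a, scale_b, torch.float16)  # fp16 output (per-tensor scaling)
```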
tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 192 additions & 1 deletion
@@ -1,11 +1,15 @@
 from functools import lru_cache
+from typing import Optional
 
 import torch
 import triton
 import triton.language as tl
-import triton.tools.experimental_descriptor
 
 from tritonbench.utils.env_utils import is_cuda
+from tritonbench.utils.triton_utils import has_experimental_descriptor
+
+if has_experimental_descriptor():
+    import triton.tools.experimental_descriptor
 
 cublas = None
 if is_cuda():
@@ -289,6 +293,36 @@ def matmul_configs():
     }
 
 
+def matmul_configs_blackwell():
+    # Autotuner does not work with TMA. Use manual config.
+    return {
+        torch.float8_e4m3fn: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 4,
+            "num_warps": 4,  # Note: num_warps >= 4 required for TMA
+        },
+        torch.float16: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 2,
+            "num_warps": 2,
+        },
+        torch.bfloat16: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 2,
+            "num_warps": 2,
+        },
+    }
+
+
 def allocate_matmul_tma(a, b):
     configs = matmul_configs()
 
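Since the autotuner is bypassed, the kernel simply indexes this table by the low-precision input dtype. A small illustration of how the table is consumed (the local variable names below are ours, not from the commit):

```python
import torch

from tritonbench.operators.fp8_gemm.persistent import matmul_configs_blackwell

cfg = matmul_configs_blackwell()[torch.float8_e4m3fn]
# For FP8 this picks 128x128x128 tiles, GROUP_SIZE_M=8, 4 stages, and 4 warps,
# meeting the num_warps >= 4 requirement that TMA imposes.
block_m, block_n, block_k = cfg["BLOCK_SIZE_M"], cfg["BLOCK_SIZE_N"], cfg["BLOCK_SIZE_K"]
num_warps = cfg["num_warps"]  # 4
```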
@@ -364,3 +398,160 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
         num_warps=configs[dtype]["num_warps"],  #
     )
     return c
+
+
+# Blackwell Persistent + TMA
+# Restrictions:
+#   - (K, N) must be a multiple of 16 on B200 for all benchmarks
+#   - num_warps >= 4
+#     - TMA instructions expect at least a 128-thread group
+#     - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
+
+
+def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
+    configs = matmul_configs_blackwell()
+
+    # Check constraints.
+    assert (
+        a.shape[1] == b.shape[1]
+    ), "Incompatible dimensions"  # a.shape = (M, K), b.shape = (N, K)
+    assert a.dtype == b.dtype, "Incompatible dtypes"
+
+    M, K = a.shape
+    N, K = b.shape
+    shape_dtype = a.dtype  # low-precision dtype, e.g. fp8
+
+    NUM_SMS = torch.cuda.get_device_properties(
+        torch.cuda.current_device()
+    ).multi_processor_count
+
+    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        return torch.empty(size, dtype=torch.int8, device=a.device)
+
+    if hasattr(triton, "set_allocator"):
+        triton.set_allocator(alloc_fn)
+    else:
+        return c
+
+    if acc_dtype == torch.float16:
+        acc_dtype_tl = tl.float16
+    elif acc_dtype == torch.bfloat16:
+        acc_dtype_tl = tl.bfloat16
+    else:
+        raise NotImplementedError(
+            "Output types other than torch.float16 and torch.bfloat16 unsupported for FP8 Blackwell persistent + TMA kernels"
+        )
+
+    grid = lambda META: (
+        min(
+            NUM_SMS,
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        ),
+    )
+    blackwell_persistent_tma_kernel[grid](
+        a,
+        b,
+        c,  #
+        M,
+        N,
+        K,  #
+        scale_a.item(),  #
+        scale_b.item(),  #
+        BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"],  #
+        BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"],  #
+        BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"],  #
+        GROUP_SIZE_M=configs[shape_dtype]["GROUP_SIZE_M"],  #
+        ACC_TYPE=acc_dtype_tl,
+        NUM_SMS=NUM_SMS,  #
+        num_stages=configs[shape_dtype]["num_stages"],  #
+        num_warps=configs[shape_dtype]["num_warps"],  #
+    )
+    return c
+
+
+@triton.jit
+def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
+    group_id = tile_id // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (tile_id % group_size_m)
+    pid_n = (tile_id % num_pid_in_group) // group_size_m
+    return pid_m, pid_n
+
+
+@triton.jit(launch_metadata=_matmul_launch_metadata)
+def blackwell_persistent_tma_kernel(
+    a,
+    b,
+    acc,
+    M,  #
+    N,  #
+    K,  #
+    scale_a,  #
+    scale_b,  #
+    BLOCK_SIZE_M: tl.constexpr,  #
+    BLOCK_SIZE_N: tl.constexpr,  #
+    BLOCK_SIZE_K: tl.constexpr,  #
+    GROUP_SIZE_M: tl.constexpr,  #
+    ACC_TYPE: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+):  #
+    start_pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+    num_tiles = num_pid_m * num_pid_n
+
+    a_desc = tl.make_tensor_descriptor(
+        a,
+        shape=[M, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+    )
+    b_desc = tl.make_tensor_descriptor(
+        b,
+        shape=[N, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
+    )
+    acc_desc = tl.make_tensor_descriptor(
+        acc,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+    )
+
+    tile_id_c = start_pid - NUM_SMS
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+
+    for tile_id in tl.range(
+        start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
+    ):
+        pid_m, pid_n = _compute_pid(
+            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+        )
+        offs_am = pid_m * BLOCK_SIZE_M
+        offs_bn = pid_n * BLOCK_SIZE_N
+
+        accumulator = tl.zeros(
+            (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
+        )  # accumulate in high precision (fp32) for high accuracy
+        for ki in range(k_tiles):
+            offs_k = ki * BLOCK_SIZE_K
+            a_block = a_desc.load([offs_am, offs_k])
+            b_block = b_desc.load([offs_bn, offs_k])
+            accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
+
+        accumulator *= scale_a * scale_b  # currently only supports per-tensor scaling
+
+        tile_id_c += NUM_SMS
+        pid_m, pid_n = _compute_pid(
+            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+        )
+        offs_cm = pid_m * BLOCK_SIZE_M
+        offs_cn = pid_n * BLOCK_SIZE_N
+
+        c = accumulator.to(ACC_TYPE)
+        acc_desc.store([offs_cm, offs_cn], c)
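To make the persistent scheduling concrete, the host-side sketch below (plain Python, not part of the commit) reproduces the tile-to-program mapping: the grid launches `min(NUM_SMS, num_tiles)` programs, each strides over tile ids by `NUM_SMS`, and `_compute_pid` converts a flat tile id into a grouped `(pid_m, pid_n)` pair. The problem and SM sizes used are toy assumptions.

```python
import math

def compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
    # Same arithmetic as the Triton _compute_pid above (minus its unused NUM_SMS argument).
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# Toy, assumed sizes: 512x512 output, 128x128 tiles, GROUP_SIZE_M = 8, 148 SMs.
M, N, BLOCK_M, BLOCK_N, GROUP_M, NUM_SMS = 512, 512, 128, 128, 8, 148
num_pid_m = math.ceil(M / BLOCK_M)      # 4 tile rows
num_pid_n = math.ceil(N / BLOCK_N)      # 4 tile columns
num_tiles = num_pid_m * num_pid_n       # 16 output tiles
num_pid_in_group = GROUP_M * num_pid_n

# The launch grid is min(NUM_SMS, num_tiles) persistent programs; each one walks
# tile ids start_pid, start_pid + NUM_SMS, ... and maps them to (pid_m, pid_n).
for start_pid in range(min(NUM_SMS, num_tiles)):
    tiles = [
        compute_pid(t, num_pid_in_group, num_pid_m, GROUP_M)
        for t in range(start_pid, num_tiles, NUM_SMS)
    ]
    # With 16 tiles and 148 SMs every program owns exactly one tile,
    # e.g. start_pid 0 -> [(0, 0)], start_pid 5 -> [(1, 1)].
```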
