Skip to content

Commit 3cddb89

Browse files
jananisriram authored and facebook-github-bot committed
Support per-row scaling (#424)
Summary: Support per-row scaling for the FP8 Blackwell persistent + TMA kernel with warp specialization. Differential Revision: D82516347
1 parent 5a2e267 commit 3cddb89

File tree

2 files changed

+35
-19
lines changed

2 files changed

+35
-19
lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
180180

181181
@register_benchmark(enabled=True)
182182
def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
183-
return lambda: blackwell_persistent_tma(
184-
a, b.T, scale_a, scale_b.T, self._get_dtype()
185-
)
183+
return lambda: blackwell_persistent_tma(a, b.T, scale_a, scale_b.T, self._get_dtype(), self.extra_args.scaling_rowwise)
186184

187185
@register_benchmark()
188186
def triton_fp8_gemm(self, a, b, scale_a, scale_b):

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
405405
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
406406

407407

408-
def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
408+
def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise):
409409
configs = matmul_configs_blackwell()
410410

411411
# Check constraints.
@@ -454,8 +454,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
454454
M,
455455
N,
456456
K, #
457-
scale_a.item(), #
458-
scale_b.item(), #
457+
scale_a_ptr, #
458+
scale_b_ptr, #
459459
BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"], #
460460
BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"], #
461461
BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"], #
@@ -464,12 +464,13 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
464464
NUM_SMS=NUM_SMS, #
465465
num_stages=configs[shape_dtype]["num_stages"], #
466466
num_warps=configs[shape_dtype]["num_warps"], #
467+
SCALING_ROWWISE=scaling_rowwise,
467468
)
468469
return c
469470

470471

471472
@triton.jit
472-
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
473+
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
473474
group_id = tile_id // num_pid_in_group
474475
first_pid_m = group_id * GROUP_SIZE_M
475476
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
@@ -485,14 +486,15 @@ def blackwell_persistent_tma_kernel(
485486
M, #
486487
N, #
487488
K, #
488-
scale_a, #
489-
scale_b, #
489+
scale_a_ptr, #
490+
scale_b_ptr, #
490491
BLOCK_SIZE_M: tl.constexpr, #
491492
BLOCK_SIZE_N: tl.constexpr, #
492493
BLOCK_SIZE_K: tl.constexpr, #
493494
GROUP_SIZE_M: tl.constexpr, #
494495
ACC_TYPE: tl.constexpr,
495496
NUM_SMS: tl.constexpr,
497+
SCALING_ROWWISE: tl.constexpr, #
496498
): #
497499
start_pid = tl.program_id(axis=0)
498500
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -522,12 +524,17 @@ def blackwell_persistent_tma_kernel(
522524
tile_id_c = start_pid - NUM_SMS
523525
num_pid_in_group = GROUP_SIZE_M * num_pid_n
524526

525-
for tile_id in tl.range(
526-
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
527-
):
528-
pid_m, pid_n = _compute_pid(
529-
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
530-
)
527+
if SCALING_ROWWISE:
528+
# For row-wise scaling, we'll use the pointers as-is
529+
scale_a = scale_a_ptr
530+
scale_b = scale_b_ptr
531+
else:
532+
# For per-tensor scaling, we'll load the scalar values
533+
scale_a = tl.load(scale_a_ptr)
534+
scale_b = tl.load(scale_b_ptr)
535+
536+
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True):
537+
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
531538
offs_am = pid_m * BLOCK_SIZE_M
532539
offs_bn = pid_n * BLOCK_SIZE_N
533540

@@ -540,12 +547,23 @@ def blackwell_persistent_tma_kernel(
540547
b_block = b_desc.load([offs_bn, offs_k])
541548
accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
542549

543-
accumulator *= scale_a * scale_b # currently only supports per-tensor scaling
550+
if SCALING_ROWWISE:
551+
offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
552+
offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
553+
554+
scale_a_block = tl.load(scale_a + offs_scale_m, mask=offs_am < M, other=0.0)
555+
scale_b_block = tl.load(scale_b + offs_scale_n, mask=offs_bn < N, other=0.0)
556+
557+
a_scales = scale_a_block[:, None]
558+
b_scales = scale_b_block[None, :]
559+
else:
560+
a_scales = scale_a
561+
b_scales = scale_b
562+
563+
accumulator *= a_scales * b_scales
544564

545565
tile_id_c += NUM_SMS
546-
pid_m, pid_n = _compute_pid(
547-
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
548-
)
566+
pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
549567
offs_cm = pid_m * BLOCK_SIZE_M
550568
offs_cn = pid_n * BLOCK_SIZE_N
551569

0 commit comments

Comments (0)