Skip to content

Commit df34b51

Browse files
jananisriram authored and facebook-github-bot committed
Support per-row scaling (#424)
Summary: Support per-row scaling for the FP8 Blackwell persistent + TMA kernel with warp specialization. Reviewed By: njriasan Differential Revision: D82516347
1 parent 434ca4c commit df34b51

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,12 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
181181
@register_benchmark(enabled=True)
182182
def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
183183
return lambda: blackwell_persistent_tma(
184-
a, b.T, scale_a, scale_b.T, self._get_dtype()
184+
a,
185+
b.T,
186+
scale_a,
187+
scale_b.T,
188+
self._get_dtype(),
189+
self.extra_args.scaling_rowwise,
185190
)
186191

187192
@register_benchmark()

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,9 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
408408
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
409409

410410

411-
def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
411+
def blackwell_persistent_tma(
412+
a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
413+
):
412414
configs = matmul_configs_blackwell()
413415

414416
# Check constraints.
@@ -457,8 +459,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
457459
M,
458460
N,
459461
K, #
460-
scale_a.item(), #
461-
scale_b.item(), #
462+
scale_a_ptr, #
463+
scale_b_ptr, #
462464
BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"], #
463465
BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"], #
464466
BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"], #
@@ -467,12 +469,13 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
467469
NUM_SMS=NUM_SMS, #
468470
num_stages=configs[shape_dtype]["num_stages"], #
469471
num_warps=configs[shape_dtype]["num_warps"], #
472+
SCALING_ROWWISE=scaling_rowwise,
470473
)
471474
return c
472475

473476

474477
@triton.jit
475-
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
478+
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
476479
group_id = tile_id // num_pid_in_group
477480
first_pid_m = group_id * GROUP_SIZE_M
478481
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
@@ -489,14 +492,15 @@ def blackwell_persistent_tma_kernel(
489492
M, #
490493
N, #
491494
K, #
492-
scale_a, #
493-
scale_b, #
495+
scale_a_ptr, #
496+
scale_b_ptr, #
494497
BLOCK_SIZE_M: tl.constexpr, #
495498
BLOCK_SIZE_N: tl.constexpr, #
496499
BLOCK_SIZE_K: tl.constexpr, #
497500
GROUP_SIZE_M: tl.constexpr, #
498501
ACC_TYPE: tl.constexpr,
499502
NUM_SMS: tl.constexpr,
503+
SCALING_ROWWISE: tl.constexpr, #
500504
): #
501505
start_pid = tl.program_id(axis=0)
502506
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -526,11 +530,20 @@ def blackwell_persistent_tma_kernel(
526530
tile_id_c = start_pid - NUM_SMS
527531
num_pid_in_group = GROUP_SIZE_M * num_pid_n
528532

533+
if SCALING_ROWWISE:
534+
# For row-wise scaling, we'll use the pointers as-is
535+
scale_a = scale_a_ptr
536+
scale_b = scale_b_ptr
537+
else:
538+
# For per-tensor scaling, we'll load the scalar values
539+
scale_a = tl.load(scale_a_ptr)
540+
scale_b = tl.load(scale_b_ptr)
541+
529542
for tile_id in tl.range(
530543
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
531544
):
532545
pid_m, pid_n = _compute_pid(
533-
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
546+
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M
534547
)
535548
offs_am = pid_m * BLOCK_SIZE_M
536549
offs_bn = pid_n * BLOCK_SIZE_N
@@ -544,12 +557,23 @@ def blackwell_persistent_tma_kernel(
544557
b_block = b_desc.load([offs_bn, offs_k])
545558
accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
546559

547-
accumulator *= scale_a * scale_b # currently only supports per-tensor scaling
560+
if SCALING_ROWWISE:
561+
offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
562+
offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
563+
564+
scale_a_block = tl.load(scale_a + offs_scale_m, mask=offs_am < M, other=0.0)
565+
scale_b_block = tl.load(scale_b + offs_scale_n, mask=offs_bn < N, other=0.0)
566+
567+
a_scales = scale_a_block[:, None]
568+
b_scales = scale_b_block[None, :]
569+
else:
570+
a_scales = scale_a
571+
b_scales = scale_b
572+
573+
accumulator *= a_scales * b_scales
548574

549575
tile_id_c += NUM_SMS
550-
pid_m, pid_n = _compute_pid(
551-
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
552-
)
576+
pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
553577
offs_cm = pid_m * BLOCK_SIZE_M
554578
offs_cn = pid_n * BLOCK_SIZE_N
555579

0 commit comments

Comments (0)