Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tritonbench/operators/fp8_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
@register_benchmark(enabled=True)
def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
return lambda: blackwell_persistent_tma(
a, b.T, scale_a, scale_b.T, self._get_dtype()
a,
b.T,
scale_a,
scale_b.T,
self._get_dtype(),
self.extra_args.scaling_rowwise,
)

@register_benchmark()
Expand Down
46 changes: 35 additions & 11 deletions tritonbench/operators/fp8_gemm/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,9 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps


def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
def blackwell_persistent_tma(
a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
):
configs = matmul_configs_blackwell()

# Check constraints.
Expand Down Expand Up @@ -457,8 +459,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
M,
N,
K, #
scale_a.item(), #
scale_b.item(), #
scale_a_ptr, #
scale_b_ptr, #
BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"], #
BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"], #
BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"], #
Expand All @@ -467,12 +469,13 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
NUM_SMS=NUM_SMS, #
num_stages=configs[shape_dtype]["num_stages"], #
num_warps=configs[shape_dtype]["num_warps"], #
SCALING_ROWWISE=scaling_rowwise,
)
return c


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
Expand All @@ -489,14 +492,15 @@ def blackwell_persistent_tma_kernel(
M, #
N, #
K, #
scale_a, #
scale_b, #
scale_a_ptr, #
scale_b_ptr, #
BLOCK_SIZE_M: tl.constexpr, #
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACC_TYPE: tl.constexpr,
NUM_SMS: tl.constexpr,
SCALING_ROWWISE: tl.constexpr, #
): #
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
Expand Down Expand Up @@ -526,12 +530,19 @@ def blackwell_persistent_tma_kernel(
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = GROUP_SIZE_M * num_pid_n

if SCALING_ROWWISE:
# For row-wise scaling, we'll use the pointers as-is
scale_a = scale_a_ptr
scale_b = scale_b_ptr
else:
# For per-tensor scaling, we'll load the scalar values
scale_a = tl.load(scale_a_ptr)
scale_b = tl.load(scale_b_ptr)

for tile_id in tl.range(
start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
):
pid_m, pid_n = _compute_pid(
tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
offs_am = pid_m * BLOCK_SIZE_M
offs_bn = pid_n * BLOCK_SIZE_N

Expand All @@ -544,11 +555,24 @@ def blackwell_persistent_tma_kernel(
b_block = b_desc.load([offs_bn, offs_k])
accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)

accumulator *= scale_a * scale_b # currently only supports per-tensor scaling
if SCALING_ROWWISE:
    # Row-wise scaling: scale_a / scale_b are base pointers here (not loaded
    # scalars), so gather the slice of each scale vector that covers this
    # output tile. Presumably scale_a has M entries and scale_b has N entries
    # -- confirm against the caller.
    offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
    offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)

    # The mask must be evaluated per element. Comparing the scalar tile
    # origin (offs_am / offs_bn) against M / N is true for every valid tile,
    # so it never disables a lane and the last partial tile would read past
    # the end of the scale vectors. Out-of-range lanes get a scale of 0.0.
    scale_a_block = tl.load(scale_a + offs_scale_m, mask=offs_scale_m < M, other=0.0)
    scale_b_block = tl.load(scale_b + offs_scale_n, mask=offs_scale_n < N, other=0.0)

    # Broadcast to (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N) so the product below
    # forms a per-element scale for the accumulator tile.
    a_scales = scale_a_block[:, None]
    b_scales = scale_b_block[None, :]
else:
    # Per-tensor scaling: scale_a / scale_b are used directly as multipliers.
    a_scales = scale_a
    b_scales = scale_b

accumulator *= a_scales * b_scales

tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M
)
offs_cm = pid_m * BLOCK_SIZE_M
offs_cn = pid_n * BLOCK_SIZE_N
Expand Down