@@ -408,7 +408,9 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
+def blackwell_persistent_tma(
+    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
+):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -457,8 +459,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         M,
         N,
         K,  #
-        scale_a.item(),  #
-        scale_b.item(),  #
+        scale_a_ptr,  #
+        scale_b_ptr,  #
         BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"],  #
         BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"],  #
         BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"],  #
@@ -467,12 +469,13 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
+        SCALING_ROWWISE=scaling_rowwise,
     )
     return c
 
 
 @triton.jit
-def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
+def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
     group_id = tile_id // num_pid_in_group
     first_pid_m = group_id * GROUP_SIZE_M
     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
@@ -489,14 +492,15 @@ def blackwell_persistent_tma_kernel(
     M,  #
     N,  #
     K,  #
-    scale_a,  #
-    scale_b,  #
+    scale_a_ptr,  #
+    scale_b_ptr,  #
     BLOCK_SIZE_M: tl.constexpr,  #
     BLOCK_SIZE_N: tl.constexpr,  #
     BLOCK_SIZE_K: tl.constexpr,  #
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
+    SCALING_ROWWISE: tl.constexpr,  #
 ):  #
     start_pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -526,11 +530,20 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
+    if SCALING_ROWWISE:
+        # For row-wise scaling, we'll use the pointers as-is
+        scale_a = scale_a_ptr
+        scale_b = scale_b_ptr
+    else:
+        # For per-tensor scaling, we'll load the scalar values
+        scale_a = tl.load(scale_a_ptr)
+        scale_b = tl.load(scale_b_ptr)
+
     for tile_id in tl.range(
         start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
     ):
         pid_m, pid_n = _compute_pid(
-            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M
         )
         offs_am = pid_m * BLOCK_SIZE_M
         offs_bn = pid_n * BLOCK_SIZE_N
@@ -544,12 +557,23 @@ def blackwell_persistent_tma_kernel(
             b_block = b_desc.load([offs_bn, offs_k])
             accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        accumulator *= scale_a * scale_b  # currently only supports per-tensor scaling
+        if SCALING_ROWWISE:
+            offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
+            offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
+
+            scale_a_block = tl.load(scale_a + offs_scale_m, mask=offs_am < M, other=0.0)
+            scale_b_block = tl.load(scale_b + offs_scale_n, mask=offs_bn < N, other=0.0)
+
+            a_scales = scale_a_block[:, None]
+            b_scales = scale_b_block[None, :]
+        else:
+            a_scales = scale_a
+            b_scales = scale_b
+
+        accumulator *= a_scales * b_scales
 
         tile_id_c += NUM_SMS
-        pid_m, pid_n = _compute_pid(
-            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
-        )
+        pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
         offs_cm = pid_m * BLOCK_SIZE_M
         offs_cn = pid_n * BLOCK_SIZE_N
 
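For context, a minimal usage sketch of the updated blackwell_persistent_tma wrapper with the new scaling_rowwise flag. The shapes, dtypes, and accumulator dtype below are illustrative assumptions, not values taken from this diff; only the argument order and the per-tensor vs. row-wise scale layouts follow the changes above.

import torch

# Hypothetical shapes/dtypes for illustration only.
M, N, K = 8192, 8192, 512
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# The kernel loads B as (N, K) blocks and transposes them per tile (b_block.T).
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

# Per-tensor scaling: one scalar scale per operand; with SCALING_ROWWISE=False
# the kernel reads it via tl.load(scale_a_ptr) / tl.load(scale_b_ptr).
scale_a = torch.tensor(0.5, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(0.25, device="cuda", dtype=torch.float32)
c = blackwell_persistent_tma(a, b, scale_a, scale_b, torch.float32, scaling_rowwise=False)

# Row-wise scaling: one scale per row of A (length M) and per row of B (length N);
# with SCALING_ROWWISE=True the kernel gathers a BLOCK_SIZE_M / BLOCK_SIZE_N slice
# of each vector for every output tile before applying it to the accumulator.
scale_a_rows = torch.rand(M, device="cuda", dtype=torch.float32)
scale_b_rows = torch.rand(N, device="cuda", dtype=torch.float32)
c = blackwell_persistent_tma(a, b, scale_a_rows, scale_b_rows, torch.float32, scaling_rowwise=True)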