@@ -405,7 +405,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
405
405
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
406
406
407
407
408
- def blackwell_persistent_tma (a , b , scale_a , scale_b , acc_dtype ):
408
+ def blackwell_persistent_tma (a , b , scale_a_ptr , scale_b_ptr , acc_dtype , scaling_rowwise ):
409
409
configs = matmul_configs_blackwell ()
410
410
411
411
# Check constraints.
@@ -454,8 +454,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
454
454
M ,
455
455
N ,
456
456
K , #
457
- scale_a . item () , #
458
- scale_b . item () , #
457
+ scale_a_ptr , #
458
+ scale_b_ptr , #
459
459
BLOCK_SIZE_M = configs [shape_dtype ]["BLOCK_SIZE_M" ], #
460
460
BLOCK_SIZE_N = configs [shape_dtype ]["BLOCK_SIZE_N" ], #
461
461
BLOCK_SIZE_K = configs [shape_dtype ]["BLOCK_SIZE_K" ], #
@@ -464,12 +464,13 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
464
464
NUM_SMS = NUM_SMS , #
465
465
num_stages = configs [shape_dtype ]["num_stages" ], #
466
466
num_warps = configs [shape_dtype ]["num_warps" ], #
467
+ SCALING_ROWWISE = scaling_rowwise ,
467
468
)
468
469
return c
469
470
470
471
471
472
@triton .jit
472
- def _compute_pid (tile_id , num_pid_in_group , num_pid_m , GROUP_SIZE_M , NUM_SMS ):
473
+ def _compute_pid (tile_id , num_pid_in_group , num_pid_m , GROUP_SIZE_M ):
473
474
group_id = tile_id // num_pid_in_group
474
475
first_pid_m = group_id * GROUP_SIZE_M
475
476
group_size_m = min (num_pid_m - first_pid_m , GROUP_SIZE_M )
@@ -485,14 +486,15 @@ def blackwell_persistent_tma_kernel(
485
486
M , #
486
487
N , #
487
488
K , #
488
- scale_a , #
489
- scale_b , #
489
+ scale_a_ptr , #
490
+ scale_b_ptr , #
490
491
BLOCK_SIZE_M : tl .constexpr , #
491
492
BLOCK_SIZE_N : tl .constexpr , #
492
493
BLOCK_SIZE_K : tl .constexpr , #
493
494
GROUP_SIZE_M : tl .constexpr , #
494
495
ACC_TYPE : tl .constexpr ,
495
496
NUM_SMS : tl .constexpr ,
497
+ SCALING_ROWWISE : tl .constexpr , #
496
498
): #
497
499
start_pid = tl .program_id (axis = 0 )
498
500
num_pid_m = tl .cdiv (M , BLOCK_SIZE_M )
@@ -522,12 +524,17 @@ def blackwell_persistent_tma_kernel(
522
524
tile_id_c = start_pid - NUM_SMS
523
525
num_pid_in_group = GROUP_SIZE_M * num_pid_n
524
526
525
- for tile_id in tl .range (
526
- start_pid , num_tiles , NUM_SMS , flatten = True , warp_specialize = True
527
- ):
528
- pid_m , pid_n = _compute_pid (
529
- tile_id , num_pid_in_group , num_pid_m , GROUP_SIZE_M , NUM_SMS
530
- )
527
+ if SCALING_ROWWISE :
528
+ # Row-wise scaling: keep the base pointers; the scale vectors are loaded per tile from the tile offsets
529
+ scale_a = scale_a_ptr
530
+ scale_b = scale_b_ptr
531
+ else :
532
+ # Per-tensor scaling: load the two scalar scale values once, before the tile loop
533
+ scale_a = tl .load (scale_a_ptr )
534
+ scale_b = tl .load (scale_b_ptr )
535
+
536
+ for tile_id in tl .range (start_pid , num_tiles , NUM_SMS , flatten = True , warp_specialize = True ):
537
+ pid_m , pid_n = _compute_pid (tile_id , num_pid_in_group , num_pid_m , GROUP_SIZE_M )
531
538
offs_am = pid_m * BLOCK_SIZE_M
532
539
offs_bn = pid_n * BLOCK_SIZE_N
533
540
@@ -540,12 +547,23 @@ def blackwell_persistent_tma_kernel(
540
547
b_block = b_desc .load ([offs_bn , offs_k ])
541
548
accumulator = tl .dot (a_block , b_block .T , accumulator , out_dtype = tl .float32 )
542
549
543
- accumulator *= scale_a * scale_b # currently only supports per-tensor scaling
550
+ if SCALING_ROWWISE :
551
+ offs_scale_m = offs_am + tl .arange (0 , BLOCK_SIZE_M )
552
+ offs_scale_n = offs_bn + tl .arange (0 , BLOCK_SIZE_N )
553
+
554
+ scale_a_block = tl .load (scale_a + offs_scale_m , mask = offs_scale_m < M , other = 0.0 )
555
+ scale_b_block = tl .load (scale_b + offs_scale_n , mask = offs_scale_n < N , other = 0.0 )
556
+
557
+ a_scales = scale_a_block [:, None ]
558
+ b_scales = scale_b_block [None , :]
559
+ else :
560
+ a_scales = scale_a
561
+ b_scales = scale_b
562
+
563
+ accumulator *= a_scales * b_scales
544
564
545
565
tile_id_c += NUM_SMS
546
- pid_m , pid_n = _compute_pid (
547
- tile_id_c , num_pid_in_group , num_pid_m , GROUP_SIZE_M , NUM_SMS
548
- )
566
+ pid_m , pid_n = _compute_pid (tile_id_c , num_pid_in_group , num_pid_m , GROUP_SIZE_M )
549
567
offs_cm = pid_m * BLOCK_SIZE_M
550
568
offs_cn = pid_n * BLOCK_SIZE_N
551
569
0 commit comments