 from functools import lru_cache
+from typing import Optional

 import torch
 import triton
 import triton.language as tl
-import triton.tools.experimental_descriptor

 from tritonbench.utils.env_utils import is_cuda
+from tritonbench.utils.triton_utils import has_experimental_descriptor
+
+if has_experimental_descriptor():
+    import triton.tools.experimental_descriptor

 cublas = None
 if is_cuda():
@@ -289,6 +293,36 @@ def matmul_configs():
     }


+def matmul_configs_blackwell():
+    # Autotuner does not work with TMA. Use manual config.
+    return {
+        torch.float8_e4m3fn: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 128,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 4,
+            "num_warps": 4,  # Note: num_warps >= 4 required for TMA
+        },
+        torch.float16: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 2,
+            "num_warps": 2,
+        },
+        torch.bfloat16: {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 64,
+            "GROUP_SIZE_M": 8,
+            "num_stages": 2,
+            "num_warps": 2,
+        },
+    }
+
+
 def allocate_matmul_tma(a, b):
     configs = matmul_configs()

@@ -364,3 +398,160 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
         num_warps=configs[dtype]["num_warps"],  #
     )
     return c
+
+
+# Blackwell Persistent + TMA
+# Restrictions:
+#   - (K, N) must be a multiple of 16 on B200 for all benchmarks
+#   - num_warps >= 4
+#     - TMA instructions expect at least a 128-thread group
+#     - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
+
+
+def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype):
+    configs = matmul_configs_blackwell()
+
+    # Check constraints.
+    assert (
+        a.shape[1] == b.shape[1]
+    ), "Incompatible dimensions"  # a.shape = (M, K), b.shape = (N, K)
+    assert a.dtype == b.dtype, "Incompatible dtypes"
+
+    M, K = a.shape
+    N, K = b.shape
+    shape_dtype = a.dtype  # low-precision dtype, e.g. fp8
+
+    NUM_SMS = torch.cuda.get_device_properties(
+        torch.cuda.current_device()
+    ).multi_processor_count
+
+    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
+
+    def alloc_fn(size: int, align: int, stream: Optional[int]):
+        return torch.empty(size, dtype=torch.int8, device=a.device)
+
+    if hasattr(triton, "set_allocator"):
+        triton.set_allocator(alloc_fn)
+    else:
+        return c
+
+    if acc_dtype == torch.float16:
+        acc_dtype_tl = tl.float16
+    elif acc_dtype == torch.bfloat16:
+        acc_dtype_tl = tl.bfloat16
+    else:
+        raise NotImplementedError(
+            "Output types other than torch.float16 and torch.bfloat16 unsupported for FP8 Blackwell persistent + TMA kernels"
+        )
+
+    grid = lambda META: (
+        min(
+            NUM_SMS,
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        ),
+    )
+    blackwell_persistent_tma_kernel[grid](
+        a,
+        b,
+        c,  #
+        M,
+        N,
+        K,  #
+        scale_a.item(),  #
+        scale_b.item(),  #
+        BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"],  #
+        BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"],  #
+        BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"],  #
+        GROUP_SIZE_M=configs[shape_dtype]["GROUP_SIZE_M"],  #
+        ACC_TYPE=acc_dtype_tl,
+        NUM_SMS=NUM_SMS,  #
+        num_stages=configs[shape_dtype]["num_stages"],  #
+        num_warps=configs[shape_dtype]["num_warps"],  #
+    )
+    return c
+
+
+@triton.jit
+def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
+    group_id = tile_id // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (tile_id % group_size_m)
+    pid_n = (tile_id % num_pid_in_group) // group_size_m
+    return pid_m, pid_n
+
+
+@triton.jit(launch_metadata=_matmul_launch_metadata)
+def blackwell_persistent_tma_kernel(
+    a,
+    b,
+    acc,
+    M,  #
+    N,  #
+    K,  #
+    scale_a,  #
+    scale_b,  #
+    BLOCK_SIZE_M: tl.constexpr,  #
+    BLOCK_SIZE_N: tl.constexpr,  #
+    BLOCK_SIZE_K: tl.constexpr,  #
+    GROUP_SIZE_M: tl.constexpr,  #
+    ACC_TYPE: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+):  #
+    start_pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
+    num_tiles = num_pid_m * num_pid_n
+
+    a_desc = tl.make_tensor_descriptor(
+        a,
+        shape=[M, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
+    )
+    b_desc = tl.make_tensor_descriptor(
+        b,
+        shape=[N, K],
+        strides=[K, 1],
+        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
+    )
+    acc_desc = tl.make_tensor_descriptor(
+        acc,
+        shape=[M, N],
+        strides=[N, 1],
+        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+    )
+
+    tile_id_c = start_pid - NUM_SMS
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+
+    for tile_id in tl.range(
+        start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True
+    ):
+        pid_m, pid_n = _compute_pid(
+            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+        )
+        offs_am = pid_m * BLOCK_SIZE_M
+        offs_bn = pid_n * BLOCK_SIZE_N
+
+        accumulator = tl.zeros(
+            (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
+        )  # accumulate in high precision (fp32) for high accuracy
+        for ki in range(k_tiles):
+            offs_k = ki * BLOCK_SIZE_K
+            a_block = a_desc.load([offs_am, offs_k])
+            b_block = b_desc.load([offs_bn, offs_k])
+            accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)

+        accumulator *= scale_a * scale_b  # currently only supports per-tensor scaling
+
+        tile_id_c += NUM_SMS
+        pid_m, pid_n = _compute_pid(
+            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
+        )
+        offs_cm = pid_m * BLOCK_SIZE_M
+        offs_cn = pid_n * BLOCK_SIZE_N
+
+        c = accumulator.to(ACC_TYPE)
+        acc_desc.store([offs_cm, offs_cn], c)
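
For reference, a minimal usage sketch of the new `blackwell_persistent_tma` entry point follows. This is not part of the diff: it assumes a Blackwell-class GPU, a Triton build that provides `tl.make_tensor_descriptor` and `triton.set_allocator`, and FP8 `e4m3` inputs with per-tensor scales; the shapes, scale computation, and sanity check are illustrative only.

```python
import torch

# Hypothetical driver for blackwell_persistent_tma (illustration, not code from the PR).
# Assumes blackwell_persistent_tma from this module is in scope.
M, N, K = 4096, 4096, 4096  # K and N kept multiples of 16, per the restrictions above

a = torch.randn(M, K, device="cuda", dtype=torch.float32)
b = torch.randn(N, K, device="cuda", dtype=torch.float32)  # note: B is laid out as (N, K)

# Per-tensor scales so values fit the FP8 e4m3 range.
scale_a = a.abs().max() / torch.finfo(torch.float8_e4m3fn).max
scale_b = b.abs().max() / torch.finfo(torch.float8_e4m3fn).max

a_fp8 = (a / scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b / scale_b).to(torch.float8_e4m3fn)

# C = (A @ B^T) * scale_a * scale_b, accumulated in fp32, written out as bf16.
c = blackwell_persistent_tma(a_fp8, b_fp8, scale_a, scale_b, acc_dtype=torch.bfloat16)

# Loose sanity check against a full-precision reference.
ref = a @ b.T
print((c.float() - ref).abs().max())
```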
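The grouped launch order computed by `_compute_pid` is easier to see outside the kernel. The snippet below replays the same index arithmetic in plain Python (added here purely as an illustration, not code from the PR); `group_size_m` stands in for `GROUP_SIZE_M`, and the tiny grid sizes are arbitrary.

```python
# Replay of _compute_pid's grouped tile ordering on a tiny grid (illustration only).
def compute_pid(tile_id, num_pid_m, num_pid_n, group_size_m):
    num_pid_in_group = group_size_m * num_pid_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * group_size_m
    rows_in_group = min(num_pid_m - first_pid_m, group_size_m)  # last group may be short
    pid_m = first_pid_m + (tile_id % rows_in_group)
    pid_n = (tile_id % num_pid_in_group) // rows_in_group
    return pid_m, pid_n

num_pid_m, num_pid_n, group_size_m = 4, 3, 2
for tile_id in range(num_pid_m * num_pid_n):
    print(tile_id, compute_pid(tile_id, num_pid_m, num_pid_n, group_size_m))
# Tiles walk down GROUP_SIZE_M rows of C before moving to the next column, so
# concurrently resident programs reuse a small set of A row-blocks and B
# column-blocks (the usual L2-friendly "grouped" ordering).
```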