
Commit 225789c

jananisriram authored and cleonard530 committed
[Inductor][Triton][FP8] Add a Blackwell-specific scaled persistent + TMA template for GEMMs (pytorch#163147)
Summary:

X-link: meta-pytorch/tritonbench#432

Add a Blackwell-specific scaled persistent + TMA Triton template to Inductor. This diff builds on D82515450 by adding a new set of mixins which inherit the scaling epilogue and add scaled persistent + TMA kwargs to the template. This diff also adds a benchmark for the scaled Blackwell persistent + TMA template to TritonBench `fp8_gemm`.

Note that this diff is a minimal extension to the above diff; rather than adding a new kernel for the scaled version, we opted to simply extend the epilogue to account for scaling. This template is accurate for per-tensor and per-row scaling, but may require modifications for other scaling modes, such as deepseek-style scaling, which apply scaling prior to the GEMM computation. In addition, note that epilogue subtiling is currently unsupported for both the scaled and non-scaled Blackwell templates; that functionality will be added in a subsequent diff.

Test Plan: Verified that the scaled Blackwell template adds the scaling epilogue to the generated Triton kernel by inspecting the Inductor-generated Triton kernel.

Benchmarking command:

```
TRITON_PRINT_AUTOTUNING=1 TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor TRITON_CACHE_DIR=~/personal/cache_dir_triton TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -- --op fp8_gemm --only torch_fp8_gemm,blackwell_pt2_fp8_gemm --metrics tflops,accuracy --input-loader=/home/jananisriram/personal/fp8_shapes_testing.json --scaling_rowwise --output="/home/jananisriram/personal/fp8_shapes_testing_results.csv" --atol=1e-2 --rtol=0.5 2>&1 | tee ~/personal/fp8_shapes_testing.log
```

Rollback Plan:

Differential Revision: D82597111

Pull Request resolved: pytorch#163147

Approved by: https://github.com/njriasan
1 parent c6bb6b8 · commit 225789c
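For context on the epilogue-only approach described in the summary, below is a minimal sketch of the arithmetic the scaling epilogue performs for the two supported modes. This is not the Inductor template itself; the helper name, shapes, and values are illustrative assumptions. Because the scales are applied to the accumulator after the GEMM (mirroring `torch._scaled_mm` semantics), scaling modes that must be applied before the GEMM, such as deepseek-style scaling, would need a different structure.

```python
import torch

# Hypothetical helper (not part of this PR): sketches the epilogue arithmetic.
def scaled_mm_epilogue_sketch(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16):
    # Accumulate the FP8 GEMM in float32 (the template accumulates in higher precision).
    acc = a_fp8.to(torch.float32) @ b_fp8.to(torch.float32)
    # Epilogue: apply the scales to the accumulator, then downcast.
    # Per-tensor scaling: scale_a and scale_b are 0-d float32 tensors.
    # Per-row scaling: scale_a has shape (M, 1) and scale_b has shape (1, N),
    # so they broadcast across the (M, N) accumulator.
    return (acc * scale_a * scale_b).to(out_dtype)

# Tiny usage example with per-row scales (shapes only; values are arbitrary).
M, K, N = 16, 32, 8
a = torch.randn(M, K).to(torch.float8_e4m3fn)
b = torch.randn(K, N).to(torch.float8_e4m3fn)
out = scaled_mm_epilogue_sketch(a, b, torch.rand(M, 1), torch.rand(1, N))
print(out.shape)  # torch.Size([16, 8])
```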

File tree: 3 files changed (+176, −0 lines)


test/inductor/test_max_autotune.py

Lines changed: 126 additions & 0 deletions
```diff
@@ -439,6 +439,132 @@ def mm(a, b):
         with config.patch({"max_autotune": True}):
             torch.compile(mm, dynamic=dynamic)(a, b)
 
+    # NOTE: the current Inductor template verifies that the scaling mode is either per-tensor or per-row
+    # TODO: support additional scaling modes for Blackwell
+    @unittest.skipIf(
+        not has_datacenter_blackwell_tma_device(),
+        "Need Blackwell with device-side TMA support in Triton",
+    )
+    @parametrize("dynamic", (False, True))
+    @parametrize("tma_store", (False, True))
+    def test_blackwell_max_autotune_scaled_mm_per_tensor_persistent_tma(
+        self,
+        dynamic: bool,
+        tma_store: bool,
+    ):
+        def scaled_mm(a, b, scale_a, scale_b):
+            # NOTE: Inductor constrains a to be row_major and b to be col_major
+            return torch._scaled_mm(
+                a, b.t(), scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
+            )
+
+        def get_scale_per_tensor(t):
+            scale = torch.finfo(torch.float8_e4m3fn).max / t.abs().max()
+            return scale.to(torch.float32)
+
+        # TMA requires 16-byte alignment: here we repeat the dims
+        # by a factor of 8, as float16 is 2-byte.
+        M, N, K = 32, 16, 48
+        a = (torch.randn((M, K)).to(torch.float16).to(GPU_TYPE)).repeat(8, 8)
+        b = (torch.randn((N, K)).to(torch.float16).to(GPU_TYPE)).repeat(8, 8)
+
+        scale_a = get_scale_per_tensor(a)
+        scale_b = get_scale_per_tensor(b)
+
+        a = a.to(torch.float8_e4m3fn)
+        b = b.to(torch.float8_e4m3fn)
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "triton.enable_persistent_tma_matmul": True,
+                "triton.enable_template_tma_store": tma_store,
+                "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
+            }
+        ):
+            c_actual, code = run_and_get_code(
+                torch.compile(scaled_mm, dynamic=dynamic), a, b, scale_a, scale_b
+            )
+            c_expected = scaled_mm(a, b, scale_a, scale_b)
+
+            torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=0.5)
+            if tma_store:
+                # Verify that we are using a TMA implementation
+                # Note: The tma_descriptor0 is generated by the kernel. If the
+                # code generation process changes this could change.
+                write_api = "tma_descriptor0.store"
+            else:
+                write_api = "tl.store"
+            FileCheck().check("triton_tem_fused__scaled_mm").check(
+                "triton.language.make_tensor_descriptor"
+            ).check("tl.load_tensor_descriptor").check(write_api).run(code[0])
+
+    @unittest.skipIf(
+        not has_datacenter_blackwell_tma_device(),
+        "Need Blackwell with device-side TMA support in Triton",
+    )
+    @parametrize("dynamic", (False, True))
+    @parametrize("tma_store", (False, True))
+    def test_blackwell_max_autotune_scaled_mm_per_row_persistent_tma(
+        self,
+        dynamic: bool,
+        tma_store: bool,
+    ):
+        def scaled_mm(a, b, scale_a, scale_b):
+            # NOTE: Inductor constrains a to be row_major and b to be col_major
+            return torch._scaled_mm(
+                a,
+                b.t(),
+                scale_a,
+                scale_b.t(),
+                use_fast_accum=True,
+                out_dtype=torch.bfloat16,
+            )
+
+        def get_scale_per_row(t):
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / t.abs().max(dim=1, keepdim=True).values
+            )
+            return scale.to(torch.float32)
+
+        # TMA requires 16-byte alignment: here we repeat the dims
+        # by a factor of 8, as bfloat16 is 2-byte.
+        M, N, K = 32, 16, 48
+        a = (torch.randn((M, K)).to(torch.bfloat16).to(GPU_TYPE)).repeat(8, 8)
+        b = (torch.randn((N, K)).to(torch.bfloat16).to(GPU_TYPE)).repeat(8, 8)
+
+        scale_a = get_scale_per_row(a)
+        scale_b = get_scale_per_row(b)
+
+        a = a.to(torch.float8_e4m3fn)
+        b = b.to(torch.float8_e4m3fn)
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "triton.enable_persistent_tma_matmul": True,
+                "triton.enable_template_tma_store": tma_store,
+                "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
+            }
+        ):
+            c_actual, code = run_and_get_code(
+                torch.compile(scaled_mm, dynamic=dynamic), a, b, scale_a, scale_b
+            )
+            c_expected = scaled_mm(a, b, scale_a, scale_b)
+
+            torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=0.5)
+            if tma_store:
+                # Verify that we are using a TMA implementation
+                # Note: The tma_descriptor0 is generated by the kernel. If the
+                # code generation process changes this could change.
+                write_api = "tma_descriptor0.store"
+            else:
+                write_api = "tl.store"
+            FileCheck().check("triton_tem_fused__scaled_mm").check(
+                "triton.language.make_tensor_descriptor"
+            ).check("tl.load_tensor_descriptor").check(write_api).run(code[0])
+
     @unittest.skipIf(
         not has_triton_tma_device(), "Need device-side TMA support in Triton"
     )
```

torch/_inductor/kernel/mm.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -1271,6 +1271,15 @@ def tuned_scaled_mm(
         templates_to_use.append(scaled_mm_device_tma_template)
         kwarg_overrides[scaled_mm_device_tma_template.uid] = overriders
 
+    if (
+        use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout)
+        and not bias
+    ):
+        templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template)
+        kwarg_overrides[blackwell_ws_persistent_device_tma_mm_template.uid] = (
+            overriders
+        )
+
     templates_to_use.append(mm_template)
     kwarg_overrides[mm_template.uid] = overriders
```
torch/_inductor/template_heuristics/triton.py

Lines changed: 41 additions & 0 deletions
```diff
@@ -1944,6 +1944,30 @@ def _get_template_configs_impl(
             yield template_kwargs
 
 
+# Scaled Blackwell TMA-specific mixin for scaled MM templates with TMA
+class ScaledBlackwellTMAConfigMixin(
+    BlackwellTMATemplateConfigMixin, ScaledMMConfigMixin
+):
+    """
+    Scaled Blackwell TMA-specific mixin that extends ScaledMMConfigMixin with TMA functionality.
+    This is for scaled MM templates that use device TMA on Blackwell.
+    This inherits from ScaledMMConfigMixin, which provides the scale_mm_epilogue, and adds TMA-specific options.
+    """
+
+    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
+        """
+        Warp specialization-specific filtering (BlackwellTMATemplateConfigMixin)
+        (compilation issues occur in some versions of Triton):
+        - num_warps < 4 unsafe for warpspec
+        - num_stages < 2 unsafe for warpspec
+
+        TMA-specific filtering:
+        - block_k >= 32 required for TMA (requires inner-most dimension >= 32)
+        """
+        configs = [c for c in configs if c.block_k >= 32]
+        return super()._filter_configs(configs)
+
+
 # Template-specific heuristic classes using multiple inheritance
 
 
@@ -2078,6 +2102,23 @@ def __init__(self) -> None:
         self.mm_configs = self.scaled_persistent_mm_configs
 
 
+@register_template_heuristic(
+    blackwell_ws_persistent_device_tma_mm_template.uid,  # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin
+    "cuda",
+    register=torch.version.hip is None,
+)
+class CUDAScaledBlackwellTMATemplateConfigHeuristic(
+    ScaledBlackwellTMAConfigMixin, CUDAConfigHeuristic
+):
+    """Scaled Blackwell TMA template heuristic for CUDA"""
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Override mm_configs to use scaled_persistent_mm_configs for TMA
+        # TODO: Tune scaled_persistent_mm_configs for Blackwell
+        self.mm_configs = self.scaled_persistent_mm_configs
+
+
 @register_template_heuristic(
     mm_plus_mm_template.uid,
     "cuda",
```
