
Commit edab96f

jananisriram authored and facebook-github-bot committed
[Inductor][Triton][FP8] Add a Blackwell-specific scaled persistent + TMA template for GEMMs (#163147)
Summary: X-link: meta-pytorch/tritonbench#432

Add a Blackwell-specific scaled persistent + TMA Triton template to Inductor. This diff builds on D82515450 by adding a new set of mixins that inherit the scaling epilogue and add scaled persistent + TMA kwargs to the template.

Note that this diff is a minimal extension of the above diff: rather than adding a new kernel for the scaled version, we opted to simply extend the epilogue to account for scaling. This template is accurate for per-tensor and per-row scaling, but may require modifications for other scaling modes, such as deepseek-style scaling, which apply scaling prior to the GEMM computation. In addition, note that epilogue subtiling is currently unsupported for both the scaled and non-scaled Blackwell templates; support will be added in a subsequent diff.

Test Plan: Verified that the scaled Blackwell template adds the scaling epilogue by inspecting the Inductor-generated Triton kernel.

Reviewed By: njriasan

Differential Revision: D82597111
1 parent ddc56f6 commit edab96f
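For context, here is a minimal usage sketch that mirrors the per-tensor test added in this diff and shows how the new template is selected through Inductor config. The shapes, device string, and per-tensor scale computation are illustrative only, and running it assumes a datacenter Blackwell GPU with device-side TMA support in Triton.

# Sketch only: exercises the scaled Blackwell persistent + TMA template added in this diff.
# Assumes a datacenter Blackwell GPU with device-side TMA support in Triton.
import torch
from torch._inductor import config

def scaled_mm(a, b, scale_a, scale_b):
    # Inductor constrains a to be row-major and b to be column-major.
    return torch._scaled_mm(
        a, b.t(), scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
    )

# Illustrative shapes; TMA needs 16-byte alignment on the inner dimension.
M, N, K = 256, 128, 384
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((N, K), device="cuda", dtype=torch.float16)

# Per-tensor scales, computed before casting to fp8.
scale_a = (torch.finfo(torch.float8_e4m3fn).max / a.abs().max()).to(torch.float32)
scale_b = (torch.finfo(torch.float8_e4m3fn).max / b.abs().max()).to(torch.float32)
a, b = a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn)

with config.patch(
    {
        "max_autotune": True,
        "triton.enable_persistent_tma_matmul": True,
        # Restrict autotuning to the Blackwell warp-specialized persistent TMA template.
        "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
    }
):
    out = torch.compile(scaled_mm)(a, b, scale_a, scale_b)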

File tree

3 files changed: +173 −0 lines changed

test/inductor/test_max_autotune.py

Lines changed: 124 additions & 0 deletions
@@ -439,6 +439,130 @@ def mm(a, b):
        with config.patch({"max_autotune": True}):
            torch.compile(mm, dynamic=dynamic)(a, b)

    @unittest.skipIf(
        not has_datacenter_blackwell_tma_device(),
        "Need Blackwell with device-side TMA support in Triton",
    )
    @parametrize("dynamic", (False, True))
    @parametrize("tma_store", (False, True))
    def test_blackwell_max_autotune_scaled_mm_per_tensor_persistent_tma(
        self,
        dynamic: bool,
        tma_store: bool,
    ):
        def scaled_mm(a, b, scale_a, scale_b):
            # Note that Inductor constrains a to be row_major and b to be col_major
            return torch._scaled_mm(
                a, b.t(), scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
            )

        def get_scale_per_tensor(t):
            scale = torch.finfo(torch.float8_e4m3fn).max / t.abs().max()
            return scale.to(torch.float32)

        # TMA requires 16-byte alignment: here we repeat the dims
        # by the factor of 8, as float16 is 2-byte.
        M, N, K = 32, 16, 48
        a = (torch.randn((M, K)).to(torch.float16).to(GPU_TYPE)).repeat(8, 8)
        b = (torch.randn((N, K)).to(torch.float16).to(GPU_TYPE)).repeat(8, 8)

        scale_a = get_scale_per_tensor(a)
        scale_b = get_scale_per_tensor(b)

        a = a.to(torch.float8_e4m3fn)
        b = b.to(torch.float8_e4m3fn)

        with config.patch(
            {
                "max_autotune": True,
                "triton.enable_persistent_tma_matmul": True,
                "triton.enable_template_tma_store": tma_store,
                "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
            }
        ):
            c_actual, code = run_and_get_code(
                torch.compile(scaled_mm, dynamic=dynamic), a, b, scale_a, scale_b
            )
            c_expected = scaled_mm(a, b, scale_a, scale_b)

            torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=0.5)
            if tma_store:
                # Verify that we are using a TMA implementation
                # Note: The tma_descriptor0 is generated by the kernel. If the
                # code generation process changes this could change.
                write_api = "tma_descriptor0.store"
            else:
                write_api = "tl.store"
            FileCheck().check("triton_tem_fused__scaled_mm").check(
                "triton.language.make_tensor_descriptor"
            ).check("tl.load_tensor_descriptor").check(write_api).run(code[0])

    @unittest.skipIf(
        not has_datacenter_blackwell_tma_device(),
        "Need Blackwell with device-side TMA support in Triton",
    )
    @parametrize("dynamic", (False, True))
    @parametrize("tma_store", (False, True))
    def test_blackwell_max_autotune_scaled_mm_per_row_persistent_tma(
        self,
        dynamic: bool,
        tma_store: bool,
    ):
        def scaled_mm(a, b, scale_a, scale_b):
            # Note that Inductor constrains a to be row_major and b to be col_major
            return torch._scaled_mm(
                a,
                b.t(),
                scale_a,
                scale_b.t(),
                use_fast_accum=True,
                out_dtype=torch.bfloat16,
            )

        def get_scale_per_row(t):
            scale = (
                torch.finfo(torch.float8_e4m3fn).max
                / t.abs().max(dim=1, keepdim=True).values
            )
            return scale.to(torch.float32)

        # TMA requires 16-byte alignment: here we repeat the dims
        # by the factor of 8, as bfloat16 is 2-byte.
        M, N, K = 32, 16, 48
        a = (torch.randn((M, K)).to(torch.bfloat16).to(GPU_TYPE)).repeat(8, 8)
        b = (torch.randn((N, K)).to(torch.bfloat16).to(GPU_TYPE)).repeat(8, 8)

        scale_a = get_scale_per_row(a)
        scale_b = get_scale_per_row(b)

        a = a.to(torch.float8_e4m3fn)
        b = b.to(torch.float8_e4m3fn)

        with config.patch(
            {
                "max_autotune": True,
                "triton.enable_persistent_tma_matmul": True,
                "triton.enable_template_tma_store": tma_store,
                "test_configs.autotune_choice_name_regex": "blackwell_ws_persistent_device_tma",
            }
        ):
            c_actual, code = run_and_get_code(
                torch.compile(scaled_mm, dynamic=dynamic), a, b, scale_a, scale_b
            )
            c_expected = scaled_mm(a, b, scale_a, scale_b)

            torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=0.5)
            if tma_store:
                # Verify that we are using a TMA implementation
                # Note: The tma_descriptor0 is generated by the kernel. If the
                # code generation process changes this could change.
                write_api = "tma_descriptor0.store"
            else:
                write_api = "tl.store"
            FileCheck().check("triton_tem_fused__scaled_mm").check(
                "triton.language.make_tensor_descriptor"
            ).check("tl.load_tensor_descriptor").check(write_api).run(code[0])

    @unittest.skipIf(
        not has_triton_tma_device(), "Need device-side TMA support in Triton"
    )

torch/_inductor/kernel/mm.py

Lines changed: 9 additions & 0 deletions
@@ -1271,6 +1271,15 @@ def tuned_scaled_mm(
        templates_to_use.append(scaled_mm_device_tma_template)
        kwarg_overrides[scaled_mm_device_tma_template.uid] = overriders

    if (
        use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout)
        and not bias
    ):
        templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template)
        kwarg_overrides[blackwell_ws_persistent_device_tma_mm_template.uid] = (
            overriders
        )

    templates_to_use.append(mm_template)
    kwarg_overrides[mm_template.uid] = overriders

torch/_inductor/template_heuristics/triton.py

Lines changed: 40 additions & 0 deletions
@@ -1858,6 +1858,29 @@ def _get_template_configs_impl(
            yield template_kwargs


# Scaled Blackwell TMA-specific mixin for scaled MM templates with TMA
class ScaledBlackwellTMAConfigMixin(
    BlackwellTMATemplateConfigMixin, ScaledMMConfigMixin
):
    """
    Scaled Blackwell TMA-specific mixin that extends ScaledMMConfigMixin with TMA functionality.
    This is for scaled MM templates that use device TMA on Blackwell.
    This inherits from ScaledMMConfigMixin, which inherits the scale_mm_epilogue, and adds TMA-specific options.
    """

    def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]:
        """
        Warp specialization-specific filtering (BlackwellTMATemplateConfigMixin):
        - num_warps < 4 not safe for warpspec due to compilation issues
        - num_stages < 2 not safe for warpspec due to compilation issues

        TMA-specific filtering:
        - block_k >= 32 required for TMA (requires inner-most dimension >= 32)
        """
        configs = [c for c in configs if c.block_k >= 32]
        return super()._filter_configs(configs)


# Template-specific heuristic classes using multiple inheritance

@@ -1992,6 +2015,23 @@ def __init__(self) -> None:
        self.mm_configs = self.scaled_persistent_mm_configs


@register_template_heuristic(
    blackwell_ws_persistent_device_tma_mm_template.uid,  # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin
    "cuda",
    register=torch.version.hip is None,
)
class CUDAScaledBlackwellTMATemplateConfigHeuristic(
    ScaledBlackwellTMAConfigMixin, CUDAConfigHeuristic
):
    """Scaled Blackwell TMA template heuristic for CUDA"""

    def __init__(self) -> None:
        super().__init__()
        # Override mm_configs to use scaled_persistent_mm_configs for TMA
        # TODO: Tune scaled_persistent_mm_configs for Blackwell
        self.mm_configs = self.scaled_persistent_mm_configs


@register_template_heuristic(
    mm_plus_mm_template.uid,
    "cuda",
