You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[Inductor][Triton][FP8] Add a Blackwell-specific scaled persistent + TMA template for GEMMs (#163147)
Summary:
X-link: meta-pytorch/tritonbench#432
Add a Blackwell-specific scaled persistent + TMA Triton template to Inductor. This diff builds on D82515450 by adding a new set of mixins which inherit the scaling epilogue and add scaled persistent + TMA kwargs to the template.
Note that this diff is a minimal extension to the above diff; rather than adding a new kernel for the scaled version, we opted to simply extend the epilogue to account for scaling. This template is accurate for per-tensor and per-row scaling but may require modifications for other scaling modes, such as deepseek-style scaling, which apply scaling prior to the GEMM computation.
In addition, note that epilogue subtiling is currently unsupported for both the scaled and non-scaled Blackwell templates, and functionality will be added in a subsequent diff.
Test Plan: Verified that the scaled Blackwell template adds the scaling epilogue to the generated Triton kernel by inspecting the Inductor-generated Triton kernel.
Reviewed By: njriasan
Differential Revision: D82597111
0 commit comments