[Inductor][Triton][FP8] Add a Blackwell-specific scaled persistent + TMA template for GEMMs (pytorch#163147)
Summary:
X-link: meta-pytorch/tritonbench#432
Add a Blackwell-specific scaled persistent + TMA Triton template to Inductor. This diff builds on D82515450 by adding a new set of mixins which inherit the scaling epilogue and add scaled persistent + TMA kwargs to the template.
This diff also adds a benchmark for the scaled Blackwell persistent + TMA template to TritonBench `fp8_gemm`.
Note that this diff is a minimal extension of the above diff: rather than adding a new kernel for the scaled version, we opted to simply extend the epilogue to account for scaling. This template is accurate for per-tensor and per-row scaling, but may require modifications for other scaling modes, such as deepseek-style scaling, which apply scaling prior to the GEMM computation.
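As a rough illustration (not the template code itself), the epilogue is numerically equivalent to applying the scales to the FP32 accumulator after the matmul; the names, shapes, and dtypes below are illustrative only:
```
import torch

# Hypothetical reference sketch of the scaling epilogue, not the Inductor
# template: the scales multiply the accumulator *after* the GEMM, which is why
# per-tensor and per-row scaling are covered, while modes that rescale blocks
# before accumulation (e.g. deepseek-style blockwise scaling) are not.
def scaled_gemm_reference(a_fp8, b_fp8, scale_a, scale_b):
    # GEMM accumulated in float32 (the template accumulates FP8 tiles in FP32).
    acc = torch.matmul(a_fp8.to(torch.float32), b_fp8.to(torch.float32))
    # Epilogue: apply scaling to the accumulator.
    #   per-tensor: scale_a and scale_b are scalar tensors
    #   per-row:    scale_a has shape [M, 1], scale_b has shape [1, N]
    return (acc * scale_a * scale_b).to(torch.bfloat16)
```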
In addition, note that epilogue subtiling is currently unsupported for both the scaled and non-scaled Blackwell templates; support will be added in a subsequent diff.
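For context, a minimal sketch of the user-facing path this template targets, assuming a Blackwell GPU and that Inductor's max-autotune is allowed to select the persistent + TMA template for `torch._scaled_mm` (shapes and scale layouts below are illustrative):
```
import torch

# Minimal sketch, not the benchmark itself: a per-row scaled FP8 GEMM compiled
# through Inductor with max-autotune, which is where the scaled Blackwell
# persistent + TMA template can be picked up (assumes a Blackwell GPU).
M, K, N = 4096, 4096, 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major operand
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)    # per-row scales for a
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)    # per-column scales for b

def fp8_gemm(a, b, scale_a, scale_b):
    return torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=torch.bfloat16)

compiled = torch.compile(fp8_gemm, mode="max-autotune")
out = compiled(a, b, scale_a, scale_b)
```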
Test Plan:
Verified by inspecting the Inductor-generated Triton kernel that the scaled Blackwell template adds the scaling epilogue.
Benchmarking command:
```
TRITON_PRINT_AUTOTUNING=1 TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor TRITON_CACHE_DIR=~/personal/cache_dir_triton TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -- --op fp8_gemm --only torch_fp8_gemm,blackwell_pt2_fp8_gemm --metrics tflops,accuracy --input-loader=/home/jananisriram/personal/fp8_shapes_testing.json --scaling_rowwise --output="/home/jananisriram/personal/fp8_shapes_testing_results.csv" --atol=1e-2 --rtol=0.5 2>&1 | tee ~/personal/fp8_shapes_testing.log
```
Rollback Plan:
Differential Revision: D82597111
Pull Request resolved: pytorch#163147
Approved by: https://github.com/njriasan