From 7874d9c6acfcbea118026a50152e7c9b2e3853c8 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Tue, 16 Sep 2025 23:18:36 -0700 Subject: [PATCH 1/5] Prototype persistent + TMA kernel with warp specialization (#385) Summary: Taking inspiration from D77053488, Triton's [persistent matmul](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#) tutorial, and Triton's [block scaled matmul](https://triton-lang.org/main/getting-started/tutorials/10-block-scaled-matmul.html) tutorial, write and benchmark a persistent + TMA Triton kernel for FP8 workloads on Blackwell which enables warp specialization and flattening. Recall the following for FP8 workloads: - Input shapes: `torch.float8_e4m3fn` (`tl.float8e4nv`) - Output: `torch.float16` (per-tensor, `tl.float16`) or `torch.bfloat16` (per-row, `tl.bfloat16`) - Accumulation: `torch.float32` (`tl.float32`). Note the following limitations: - `(K, N)` (second matrix in GEMM) needs to be a multiple of 16 - `num_warps >= 4`: TMA instructions expect at least a 128-thread (4 warps) group. 1 warp = 32 threads, so we need 4 warps per thread block to ensure correct functionality. Note that the current kernel is being autotuned on just one config; this will be changed in a future diff. Differential Revision: D81470285 --- tritonbench/operators/fp8_gemm/fp8_gemm.py | 15 ++ tritonbench/operators/fp8_gemm/persistent.py | 187 +++++++++++++++++++ 2 files changed, 202 insertions(+) diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py index 3a0897e97..4c1eb3b4a 100644 --- a/tritonbench/operators/fp8_gemm/fp8_gemm.py +++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py @@ -7,6 +7,9 @@ import torch._inductor.config as inductor_config import triton +from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma +from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda + from tritonbench.utils.triton_op import ( BenchmarkOperator, BenchmarkOperatorMetrics, @@ -19,6 +22,10 @@ from .tutorial import matmul as tutorial_matmul +IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200" + +torch._dynamo.config.recompile_limit = 10000 + logger = logging.getLogger(__name__) try: from .persistent import ( @@ -169,6 +176,14 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable: return lambda: compiled(a, b) + if IS_B200: + + @register_benchmark(enabled=True) + def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b): + return lambda: blackwell_persistent_tma( + a, b.T, scale_a, scale_b.T, self._get_dtype() + ) + @register_benchmark() def triton_fp8_gemm(self, a, b, scale_a, scale_b): return lambda: tutorial_matmul(a, b) diff --git a/tritonbench/operators/fp8_gemm/persistent.py b/tritonbench/operators/fp8_gemm/persistent.py index e9789f339..557e2c6b2 100644 --- a/tritonbench/operators/fp8_gemm/persistent.py +++ b/tritonbench/operators/fp8_gemm/persistent.py @@ -1,4 +1,5 @@ from functools import lru_cache +from typing import Optional import torch import triton @@ -289,6 +290,36 @@ def matmul_configs(): } +def matmul_configs_blackwell(): + # Autotuner does not work with TMA. Use manual config. 
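+    # As noted in the patch summary: FP8 inputs are torch.float8_e4m3fn (tl.float8e4nv),
+    # accumulation is fp32 (tl.float32), and the output is fp16 (per-tensor scaling) or
+    # bf16 (per-row scaling). num_warps >= 4 is required because TMA instructions expect
+    # at least a 128-thread group, and 1 warp = 32 threads, so each block needs 4 warps.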
+ return { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 4, + "num_warps": 4, # Note: num_warps >= 4 required for TMA + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 2, + "num_warps": 2, + }, + torch.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 2, + "num_warps": 2, + }, + } + + def allocate_matmul_tma(a, b): configs = matmul_configs() @@ -364,3 +395,159 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c): num_warps=configs[dtype]["num_warps"], # ) return c + + +# Blackwell Persistent + TMA +# Restrictions: +# - (K, N) must be a multiple of 16 on B200 for all benchmarks +# - num_warps >= 4 +# - TMA instructions expect at least a 128-thread group +# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps + + +def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype): + configs = matmul_configs_blackwell() + + # Check constraints. + assert ( + a.shape[1] == b.shape[1] + ), "Incompatible dimensions" # a.shape = (M, K), b.shape = (N, K) + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + shape_dtype = a.dtype # low-precision dtype, e.g. fp8 + + NUM_SMS = torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + + c = torch.zeros((M, N), device=a.device, dtype=acc_dtype) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device=a.device) + + if hasattr(triton, "set_allocator"): + triton.set_allocator(alloc_fn) + else: + return c + + if acc_dtype == torch.float16: + acc_dtype_tl = tl.float16 + elif acc_dtype == torch.bfloat16: + acc_dtype_tl = tl.bfloat16 + else: + raise NotImplementedError( + "Output types other than torch.float16 and torch.bfloat16 unsupported for FP8 Blackwell persistent + TMA kernels" + ) + + grid = lambda META: ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) + blackwell_persistent_tma_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + scale_a.item(), # + scale_b.item(), # + BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[shape_dtype]["GROUP_SIZE_M"], # + ACC_TYPE=acc_dtype_tl, + NUM_SMS=NUM_SMS, # + num_stages=configs[shape_dtype]["num_stages"], # + num_warps=configs[shape_dtype]["num_warps"], # + ) + return c + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def blackwell_persistent_tma_kernel( + a, + b, + acc, + M, # + N, # + K, # + scale_a, # + scale_b, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACC_TYPE: tl.constexpr, + NUM_SMS: tl.constexpr, +): # + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + 
a_desc = tl.make_tensor_descriptor( + a, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + b_desc = tl.make_tensor_descriptor( + b, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + acc_desc = tl.make_tensor_descriptor( + acc, + shape=[M, N], + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range( + start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True + ): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + accumulator = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 + ) # accumulate in high precision (fp32) for high accuracy + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a_block = a_desc.load([offs_am, offs_k]) + b_block = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32) + + accumulator *= scale_a * scale_b # currently only supports per-tensor scaling + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid( + tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) + offs_cm = pid_m * BLOCK_SIZE_M + offs_cn = pid_n * BLOCK_SIZE_N + + c = accumulator.to(ACC_TYPE) + acc_desc.store([offs_cm, offs_cn], c) From 1f7aefcc6040c6b7f751b135ff3f8624d811f000 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Tue, 16 Sep 2025 23:18:36 -0700 Subject: [PATCH 2/5] Support per-row scaling (#424) Summary: Support per-row scaling for the FP8 Blackwell persistent + TMA kernel with warp specialization. Differential Revision: D82516347 --- tritonbench/operators/fp8_gemm/fp8_gemm.py | 4 +- tritonbench/operators/fp8_gemm/persistent.py | 50 +++++++++++++------- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py index 4c1eb3b4a..2730479f8 100644 --- a/tritonbench/operators/fp8_gemm/fp8_gemm.py +++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py @@ -180,9 +180,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable: @register_benchmark(enabled=True) def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b): - return lambda: blackwell_persistent_tma( - a, b.T, scale_a, scale_b.T, self._get_dtype() - ) + return lambda: blackwell_persistent_tma(a, b.T, scale_a, scale_b.T, self._get_dtype(), self.extra_args.scaling_rowwise) @register_benchmark() def triton_fp8_gemm(self, a, b, scale_a, scale_b): diff --git a/tritonbench/operators/fp8_gemm/persistent.py b/tritonbench/operators/fp8_gemm/persistent.py index 557e2c6b2..3a5b1a1db 100644 --- a/tritonbench/operators/fp8_gemm/persistent.py +++ b/tritonbench/operators/fp8_gemm/persistent.py @@ -405,7 +405,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c): # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps -def blackwell_persistent_tma(a, b, scale_a, scale_b, acc_dtype): +def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise): configs = matmul_configs_blackwell() # Check constraints. 
@@ -454,8 +454,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): M, N, K, # - scale_a.item(), # - scale_b.item(), # + scale_a_ptr, # + scale_b_ptr, # BLOCK_SIZE_M=configs[shape_dtype]["BLOCK_SIZE_M"], # BLOCK_SIZE_N=configs[shape_dtype]["BLOCK_SIZE_N"], # BLOCK_SIZE_K=configs[shape_dtype]["BLOCK_SIZE_K"], # @@ -464,12 +464,13 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): NUM_SMS=NUM_SMS, # num_stages=configs[shape_dtype]["num_stages"], # num_warps=configs[shape_dtype]["num_warps"], # + SCALING_ROWWISE=scaling_rowwise, ) return c @triton.jit -def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M): group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) @@ -485,14 +486,15 @@ def blackwell_persistent_tma_kernel( M, # N, # K, # - scale_a, # - scale_b, # + scale_a_ptr, # + scale_b_ptr, # BLOCK_SIZE_M: tl.constexpr, # BLOCK_SIZE_N: tl.constexpr, # BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # ACC_TYPE: tl.constexpr, NUM_SMS: tl.constexpr, + SCALING_ROWWISE: tl.constexpr, # ): # start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -522,12 +524,17 @@ def blackwell_persistent_tma_kernel( tile_id_c = start_pid - NUM_SMS num_pid_in_group = GROUP_SIZE_M * num_pid_n - for tile_id in tl.range( - start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True - ): - pid_m, pid_n = _compute_pid( - tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS - ) + if SCALING_ROWWISE: + # For row-wise scaling, we'll use the pointers as-is + scale_a = scale_a_ptr + scale_b = scale_b_ptr + else: + # For per-tensor scaling, we'll load the scalar values + scale_a = tl.load(scale_a_ptr) + scale_b = tl.load(scale_b_ptr) + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N @@ -540,12 +547,23 @@ def blackwell_persistent_tma_kernel( b_block = b_desc.load([offs_bn, offs_k]) accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32) - accumulator *= scale_a * scale_b # currently only supports per-tensor scaling + if SCALING_ROWWISE: + offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M) + offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N) + + scale_a_block = tl.load(scale_a + offs_scale_m, mask=offs_am < M, other=0.0) + scale_b_block = tl.load(scale_b + offs_scale_n, mask=offs_bn < N, other=0.0) + + a_scales = scale_a_block[:, None] + b_scales = scale_b_block[None, :] + else: + a_scales = scale_a + b_scales = scale_b + + accumulator *= a_scales * b_scales tile_id_c += NUM_SMS - pid_m, pid_n = _compute_pid( - tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS - ) + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M) offs_cm = pid_m * BLOCK_SIZE_M offs_cn = pid_n * BLOCK_SIZE_N From c911912fb1018fa84bcd9d9a6e13f08068f99071 Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Tue, 16 Sep 2025 23:18:36 -0700 Subject: [PATCH 3/5] Add epilogue subtiling to persistent + TMA kernel (#425) Summary: Add epilogue subtiling to Blackwell FP8 persistent + TMA kernel in TritonBench. Epilogue subtiling breaks computation in the epilogue into multiple sub-tiles, allowing TMA to more efficiently overlap expensive computation (e.g. stores, scaling). 
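In isolation, the subtiled epilogue amounts to the following store pattern (a condensed excerpt of the kernel change in the diff below; `accumulator`, `acc_desc`, and the offsets are the kernel's own names):

```python
# Condensed excerpt of the epilogue change below. When EPILOGUE_SUBTILE is set, acc_desc
# is built with block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N // 2], so each store moves one
# half-width tile.
if EPILOGUE_SUBTILE:
    # Split the (BLOCK_SIZE_M, BLOCK_SIZE_N) fp32 accumulator into two half-width tiles
    # and convert + store each one separately, so the TMA stores can overlap the
    # conversion and scaling work.
    acc_reshaped = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
    acc_permuted = tl.permute(acc_reshaped, (0, 2, 1))
    acc0, acc1 = tl.split(acc_permuted)
    acc_desc.store([offs_cm, offs_cn], acc0.to(ACC_TYPE))
    acc_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], acc1.to(ACC_TYPE))
else:
    # Single full-width store of the converted accumulator.
    acc_desc.store([offs_cm, offs_cn], accumulator.to(ACC_TYPE))
```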
Reviewed By: NikhilAPatel Differential Revision: D82462670 --- tritonbench/operators/fp8_gemm/persistent.py | 29 +++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/tritonbench/operators/fp8_gemm/persistent.py b/tritonbench/operators/fp8_gemm/persistent.py index 3a5b1a1db..15bad740b 100644 --- a/tritonbench/operators/fp8_gemm/persistent.py +++ b/tritonbench/operators/fp8_gemm/persistent.py @@ -294,12 +294,14 @@ def matmul_configs_blackwell(): # Autotuner does not work with TMA. Use manual config. return { torch.float8_e4m3fn: { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_stages": 4, "num_warps": 4, # Note: num_warps >= 4 required for TMA + "WARP_SPECIALIZE": True, + "EPILOGUE_SUBTILE": True, }, torch.float16: { "BLOCK_SIZE_M": 128, @@ -465,6 +467,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): num_stages=configs[shape_dtype]["num_stages"], # num_warps=configs[shape_dtype]["num_warps"], # SCALING_ROWWISE=scaling_rowwise, + WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"], # + EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"], # ) return c @@ -495,6 +499,8 @@ def blackwell_persistent_tma_kernel( ACC_TYPE: tl.constexpr, NUM_SMS: tl.constexpr, SCALING_ROWWISE: tl.constexpr, # + WARP_SPECIALIZE: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, ): # start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -518,7 +524,7 @@ def blackwell_persistent_tma_kernel( acc, shape=[M, N], strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N // 2] if EPILOGUE_SUBTILE else [BLOCK_SIZE_M, BLOCK_SIZE_N], ) tile_id_c = start_pid - NUM_SMS @@ -533,7 +539,7 @@ def blackwell_persistent_tma_kernel( scale_a = tl.load(scale_a_ptr) scale_b = tl.load(scale_b_ptr) - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True): + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE): pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M) offs_am = pid_m * BLOCK_SIZE_M offs_bn = pid_n * BLOCK_SIZE_N @@ -567,5 +573,14 @@ def blackwell_persistent_tma_kernel( offs_cm = pid_m * BLOCK_SIZE_M offs_cn = pid_n * BLOCK_SIZE_N - c = accumulator.to(ACC_TYPE) - acc_desc.store([offs_cm, offs_cn], c) + if EPILOGUE_SUBTILE: + acc_reshaped = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc_permuted = tl.permute(acc_reshaped, (0, 2, 1)) + acc0, acc1 = tl.split(acc_permuted) + c0 = acc0.to(ACC_TYPE) + acc_desc.store([offs_cm, offs_cn], c0) + c1 = acc1.to(ACC_TYPE) + acc_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1) + else: + c = accumulator.to(ACC_TYPE) + acc_desc.store([offs_cm, offs_cn], c) From 2a989449b88fb826d8e5f2fcf4f55b04c287926f Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Tue, 16 Sep 2025 23:18:36 -0700 Subject: [PATCH 4/5] Add a Blackwell-specific scaled persistent + TMA template for GEMMs Summary: Add a Blackwell-specific scaled persistent + TMA Triton template to Inductor. This diff builds on D82515450 by adding a new set of mixins which inherit the scaling epilogue and add scaled persistent + TMA kwargs to the template. This diff also adds a benchmark for the scaled Blackwell persistent + TMA template to TritonBench `fp8_gemm`. 
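The TritonBench side of this change drives the template through `torch.compile` on `torch._scaled_mm` with Inductor max-autotune restricted to Triton GEMM backends; a condensed sketch of that call pattern, taken from the benchmark added in the diff below (the scale tensors and output dtype come from the operator's existing setup):

```python
# Condensed from the blackwell_pt2_fp8_gemm benchmark added below: compile
# torch._scaled_mm with max-autotune restricted to Triton GEMM templates so the scaled
# Blackwell persistent + TMA template is eligible for selection.
import torch
import torch._inductor.config as inductor_config


def build_scaled_mm(a, b, scale_a, scale_b, out_dtype):
    torch._dynamo.reset()
    with inductor_config.patch(
        max_autotune=True,
        max_autotune_gemm_backends="TRITON",
        autotune_fallback_to_aten=False,
    ):
        f = lambda a, b: torch._scaled_mm(
            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
        )
        compiled = torch.compile(f, dynamic=False)
        compiled(a, b)  # warm-up call triggers autotuning inside the config patch
    return compiled
```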
Note that this diff is a minimal extension to the above diff; rather than adding a new kernel for the scaled version, we opted to simply extend the epilogue to account for scaling. This template is accurate for per-tensor and per-row scaling but may require modifications for other scaling modes, such as deepseek-style scaling, which apply scaling prior to the GEMM computation. In addition, note that epilogue subtiling is currently unsupported for both the scaled and non-scaled Blackwell templates, and functionality will be added in a subsequent diff. Differential Revision: D82597111 --- tritonbench/operators/fp8_gemm/fp8_gemm.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py index 2730479f8..9eb63935e 100644 --- a/tritonbench/operators/fp8_gemm/fp8_gemm.py +++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py @@ -182,6 +182,22 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable: def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b): return lambda: blackwell_persistent_tma(a, b.T, scale_a, scale_b.T, self._get_dtype(), self.extra_args.scaling_rowwise) + @register_benchmark(enabled=True) + def blackwell_pt2_fp8_gemm(self, a, b, scale_a, scale_b): + torch._dynamo.reset() + with inductor_config.patch( + max_autotune=True, + max_autotune_gemm_backends="TRITON", + autotune_fallback_to_aten=False, + ): + f = lambda a, b: torch._scaled_mm( + a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype() + ) + compiled = torch.compile(f, dynamic=False) + compiled(a, b) + + return lambda: compiled(a, b) + @register_benchmark() def triton_fp8_gemm(self, a, b, scale_a, scale_b): return lambda: tutorial_matmul(a, b) From 8ec0503fafe5b4e8b8d3bffc4ee39949f6950d4b Mon Sep 17 00:00:00 2001 From: Janani Sriram Date: Sun, 21 Sep 2025 19:22:02 -0700 Subject: [PATCH 5/5] Update fp8_gemm.py lint --- tritonbench/operators/fp8_gemm/fp8_gemm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py index 888de0350..fd27b48f6 100644 --- a/tritonbench/operators/fp8_gemm/fp8_gemm.py +++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py @@ -198,7 +198,12 @@ def blackwell_pt2_fp8_gemm(self, a, b, scale_a, scale_b): autotune_fallback_to_aten=False, ): f = lambda a, b: torch._scaled_mm( - a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype() + a, + b, + scale_a, + scale_b, + use_fast_accum=True, + out_dtype=self._get_dtype() ) compiled = torch.compile(f, dynamic=False) compiled(a, b)
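For reference, a minimal usage sketch of calling the new kernel directly, mirroring the benchmark's `blackwell_persistent_tma(a, b.T, scale_a, scale_b.T, dtype, scaling_rowwise)` call pattern from the diffs above. Shapes and scale values are illustrative; the sketch assumes an NVIDIA B200 and a Triton build that provides `triton.set_allocator` and `tl.make_tensor_descriptor`:

```python
# Minimal usage sketch (illustrative; not part of the patches above).
import torch

from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma

M, N, K = 1024, 1024, 1024  # K and N kept as multiples of 16, per the TMA restriction
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)  # (M, K) fp8 operand
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)  # (N, K), i.e. second operand already transposed

# Per-tensor scaling: scalar scales, fp16 output.
scale_a = torch.tensor([1.0], device="cuda")
scale_b = torch.tensor([1.0], device="cuda")
c = blackwell_persistent_tma(a, b, scale_a, scale_b, torch.float16, scaling_rowwise=False)

# Per-row scaling: one scale per row of a and one per row of b, bf16 output.
scale_a_rows = torch.ones(M, device="cuda", dtype=torch.float32)
scale_b_rows = torch.ones(N, device="cuda", dtype=torch.float32)
c_rowwise = blackwell_persistent_tma(
    a, b, scale_a_rows, scale_b_rows, torch.bfloat16, scaling_rowwise=True
)
```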