Conversation

jananisriram
Contributor

Summary:
Taking inspiration from D77053488, Triton's persistent matmul tutorial, and Triton's block scaled matmul tutorial, this PR writes and benchmarks a persistent + TMA Triton kernel for FP8 workloads on Blackwell that enables warp specialization and flattening.

Note the following limitations (illustrated in the sketch below):

  • (K, N) (the second GEMM operand): both K and N must be multiples of 16.
  • num_warps >= 4: TMA instructions expect a group of at least 128 threads. One warp is 32 threads, so each thread block needs at least 4 warps for correct functionality.

Differential Revision: D81470285
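
To make the constraints above concrete, here is a minimal host-side sketch of the input checks and output allocation a caller of such a kernel might perform. The function name `check_fp8_gemm_inputs` and the per-tensor fp16 output choice are illustrative assumptions, not the API added in this diff.

```python
import torch


def check_fp8_gemm_inputs(a: torch.Tensor, b: torch.Tensor, num_warps: int = 4) -> torch.Tensor:
    """Validate FP8 GEMM inputs against the limitations above and allocate the output.

    Illustrative sketch only, not the API introduced in this diff.
    """
    # Inputs are FP8: torch.float8_e4m3fn on the host, tl.float8e4nv inside Triton.
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    M, K = a.shape
    Kb, N = b.shape
    assert K == Kb, "inner dimensions must match"
    # The second operand's dimensions (K, N) must both be multiples of 16.
    assert K % 16 == 0 and N % 16 == 0, "(K, N) must be multiples of 16"
    # TMA expects a group of at least 128 threads, i.e. 4 warps of 32 threads.
    assert num_warps >= 4, "TMA requires num_warps >= 4"
    # Accumulation happens in fp32 inside the kernel; the per-tensor-scaled
    # convention returns fp16 (per-row scaling would return bf16 instead).
    return torch.empty((M, N), device=a.device, dtype=torch.float16)
```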

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D81470285

facebook-github-bot pushed a commit that referenced this pull request Sep 16, 2025
@facebook-github-bot
Contributor

@jananisriram has exported this pull request. If you are a Meta employee, you can view the originating diff in D81470285.

facebook-github-bot pushed a commit that referenced this pull request Sep 16, 2025
facebook-github-bot pushed a commit that referenced this pull request Sep 16, 2025
jananisriram added a commit that referenced this pull request Sep 16, 2025
jananisriram added a commit that referenced this pull request Sep 16, 2025
jananisriram added a commit that referenced this pull request Sep 16, 2025
jananisriram added a commit that referenced this pull request Sep 16, 2025
jananisriram added a commit that referenced this pull request Sep 16, 2025
facebook-github-bot pushed a commit that referenced this pull request Sep 17, 2025
facebook-github-bot pushed a commit that referenced this pull request Sep 17, 2025
facebook-github-bot pushed a commit that referenced this pull request Sep 17, 2025
jananisriram added a commit that referenced this pull request Sep 17, 2025
jananisriram added a commit that referenced this pull request Sep 17, 2025
jananisriram added a commit that referenced this pull request Sep 18, 2025
facebook-github-bot pushed a commit that referenced this pull request Sep 18, 2025
facebook-github-bot pushed a commit that referenced this pull request Sep 18, 2025
Summary:

Taking inspiration from D77053488, Triton's [persistent matmul](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#) tutorial, and Triton's [block scaled matmul](https://triton-lang.org/main/getting-started/tutorials/10-block-scaled-matmul.html) tutorial, this diff writes and benchmarks a persistent + TMA Triton kernel for FP8 workloads on Blackwell that enables warp specialization and flattening.

Recall the following conventions for FP8 workloads:
- Inputs: `torch.float8_e4m3fn` (`tl.float8e4nv`)
- Output: `torch.float16` (per-tensor scaling, `tl.float16`) or `torch.bfloat16` (per-row scaling, `tl.bfloat16`)
- Accumulation: `torch.float32` (`tl.float32`)

Note the following limitations:
- `(K, N)` (the second GEMM operand): both K and N must be multiples of 16.
- `num_warps >= 4`: TMA instructions expect a group of at least 128 threads. One warp is 32 threads, so each thread block needs at least 4 warps for correct functionality.

The current kernel is autotuned on only one config; this will be changed in a future diff. A simplified persistent-loop sketch follows this commit message.

Reviewed By: njriasan

Differential Revision: D81470285
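
As referenced above, here is a minimal, self-contained Triton sketch of the persistent-scheduling pattern described in this summary. It uses plain pointer loads rather than TMA descriptors and omits warp specialization, flattening, scaling, and boundary masking (it assumes M, N, and K are multiples of the block sizes); the kernel and wrapper names are illustrative, not the ones introduced in this diff.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def persistent_fp8_matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    # Persistent scheduling: launch roughly one program per SM and let each
    # program walk over many output tiles, instead of one program per tile.
    start_pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_tiles = num_pid_m * num_pid_n

    for tile_id in range(start_pid, num_tiles, NUM_SMS):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n

        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)

        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

        # FP8 (tl.float8e4nv) inputs accumulate in fp32.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            acc = tl.dot(a, b, acc)
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk

        # Per-tensor convention: store the result as fp16.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, acc.to(tl.float16))


def persistent_fp8_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a, b: torch.float8_e4m3fn tensors whose K and N are multiples of 16.
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count
    # Persistent launch: at most one program per SM.
    grid = (min(NUM_SMS, triton.cdiv(M, 128) * triton.cdiv(N, 128)),)
    persistent_fp8_matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=128, BLOCK_N=128, BLOCK_K=64,
        NUM_SMS=NUM_SMS,
        num_warps=4,  # TMA / warp specialization needs at least 4 warps (128 threads)
    )
    return c
```

The actual kernel in this diff additionally drives its loads and stores through TMA descriptors and enables warp specialization and flattening; those pieces are deliberately left out here to keep the persistent scheduling pattern visible.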
@jananisriram
Contributor Author

@pytorchbot merge


pytorch-bot bot commented Sep 18, 2025

Mergebot is not configured for this repository. Please use the merge button provided by GitHub.

facebook-github-bot merged commit 434ca4c into main Sep 18, 2025
8 checks passed