Prototype FP8 Blackwell persistent + TMA kernel with warp specialization #385
Conversation
This pull request was exported from Phabricator. Differential Revision: D81470285
Force-pushed from e6578cb to 11ddc62
Summary: Taking inspiration from D77053488, Triton's [persistent matmul](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#) tutorial, and Triton's [block scaled matmul](https://triton-lang.org/main/getting-started/tutorials/10-block-scaled-matmul.html) tutorial, write and benchmark a persistent + TMA Triton kernel for FP8 workloads on Blackwell which enables warp specialization and flattening.

Recall the following for FP8 workloads:
- Input dtype: `torch.float8_e4m3fn` (`tl.float8e4nv`)
- Output dtype: `torch.float16` (per-tensor, `tl.float16`) or `torch.bfloat16` (per-row, `tl.bfloat16`)
- Accumulation dtype: `torch.float32` (`tl.float32`)

Note the following limitations:
- `(K, N)` (the second matrix in the GEMM) needs to be a multiple of 16.
- `num_warps >= 4`: TMA instructions expect at least a 128-thread (4-warp) group. One warp is 32 threads, so we need at least 4 warps per thread block for correct functionality.

Note that the current kernel is autotuned on just one config; this will be changed in a future diff.

Differential Revision: D81470285
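For readers unfamiliar with the pattern, the following is a minimal sketch of what a persistent + TMA FP8 kernel of this shape can look like, modeled on the two tutorials linked above rather than on the actual kernel in this diff. The kernel name, block sizes, and descriptor arguments are hypothetical, and the `tl.range(..., warp_specialize=, flatten=)` hints and descriptor `.load`/`.store` calls follow recent Triton releases, so the exact API may differ on other versions.

```python
# Hypothetical sketch, not the kernel from this diff. Assumes a recent Triton with
# host-side TMA tensor descriptors and tl.range(..., warp_specialize=, flatten=).
import triton
import triton.language as tl


@triton.jit
def fp8_persistent_tma_matmul(          # hypothetical name
    a_desc, b_desc, c_desc,             # TMA tensor descriptors built on the host
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    # Persistent scheduling: one program per SM; each program loops over a strided
    # set of output tiles instead of handling exactly one tile.
    start_pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_tiles = num_pid_m * num_pid_n
    k_tiles = tl.cdiv(K, BLOCK_K)

    # warp_specialize / flatten mirror the options this PR enables; in recent Triton
    # they are hints on the persistent loop.
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS,
                            flatten=True, warp_specialize=True):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n
        offs_m = pid_m * BLOCK_M
        offs_n = pid_n * BLOCK_N

        # FP8 (e4m3) inputs accumulate in fp32, as listed in the summary.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(k_tiles):
            offs_k = k * BLOCK_K
            a = a_desc.load([offs_m, offs_k])  # (BLOCK_M, BLOCK_K) fp8 tile via TMA
            b = b_desc.load([offs_k, offs_n])  # (BLOCK_K, BLOCK_N) fp8 tile via TMA
            acc = tl.dot(a, b, acc)

        # Downcast to the fp16 output dtype (per-tensor case) before the TMA store.
        c_desc.store([offs_m, offs_n], acc.to(tl.float16))
```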
Force-pushed from 11ddc62 to 8e64a26
@jananisriram has exported this pull request. If you are a Meta employee, you can view the originating diff in D81470285.
Force-pushed from 8e64a26 to 13ceedb
Force-pushed from 13ceedb to ee54388
Force-pushed from 0c68de7 to aa5ffc7
Force-pushed from aa5ffc7 to ef3c2a9
Summary: Taking inspiration from D77053488, Triton's [persistent matmul](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#) tutorial, and Triton's [block scaled matmul](https://triton-lang.org/main/getting-started/tutorials/10-block-scaled-matmul.html) tutorial, write and benchmark a persistent + TMA Triton kernel for FP8 workloads on Blackwell which enables warp specialization and flattening.

Recall the following for FP8 workloads:
- Input dtype: `torch.float8_e4m3fn` (`tl.float8e4nv`)
- Output dtype: `torch.float16` (per-tensor, `tl.float16`) or `torch.bfloat16` (per-row, `tl.bfloat16`)
- Accumulation dtype: `torch.float32` (`tl.float32`)

Note the following limitations:
- `(K, N)` (the second matrix in the GEMM) needs to be a multiple of 16.
- `num_warps >= 4`: TMA instructions expect at least a 128-thread (4-warp) group. One warp is 32 threads, so we need at least 4 warps per thread block for correct functionality.

Note that the current kernel is autotuned on just one config; this will be changed in a future diff.

Reviewed By: njriasan

Differential Revision: D81470285
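The "one config for now" note above might look roughly like the following on the host side; this is a hypothetical illustration, and the block sizes, stage count, and stub kernel name are placeholders rather than the values used in this diff.

```python
# Hypothetical single-config autotune setup; values are placeholders, not from this diff.
import triton
import triton.language as tl

_FP8_CONFIGS = [
    triton.Config(
        {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128},
        num_warps=4,   # >= 4 warps so TMA has its 128-thread group
        num_stages=3,
    ),
]


@triton.autotune(configs=_FP8_CONFIGS, key=["M", "N", "K"])
@triton.jit
def _fp8_gemm_kernel_stub(a_desc, b_desc, c_desc, M, N, K,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                          BLOCK_K: tl.constexpr, NUM_SMS: tl.constexpr):
    # Body omitted; see the persistent + TMA sketch earlier in this conversation.
    pass
```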
Force-pushed from ef3c2a9 to 161229f
@pytorchbot merge
Mergebot is not configured for this repository. Please use the merge button provided by GitHub.
@pytorchbot merge
Mergebot is not configured for this repository. Please use the merge button provided by GitHub.
Summary:
Taking inspiration from D77053488, Triton's persistent matmul tutorial, and Triton's block scaled matmul tutorial, write and benchmark a persistent + TMA Triton kernel for FP8 workloads on Blackwell which enables warp specialization and flattening.

Note the following limitations:
- `(K, N)` (the second matrix in the GEMM) needs to be a multiple of 16.
- `num_warps >= 4`: TMA instructions expect at least a 128-thread (4-warp) group. One warp is 32 threads, so we need at least 4 warps per thread block for correct functionality.

Differential Revision: D81470285
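As a companion to the kernel sketch earlier in this conversation, here is a hedged host-side wrapper showing where the limitations above would be enforced. The wrapper name is hypothetical, it launches the hypothetical `fp8_persistent_tma_matmul` kernel sketched above, and the `TensorDescriptor.from_tensor` API follows the Triton tutorials referenced in the summary (it may differ across Triton releases).

```python
# Hypothetical host-side wrapper, not the benchmark harness from this diff.
import torch
import triton
from triton.tools.tensor_descriptor import TensorDescriptor


def fp8_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # FP8 e4m3 inputs, fp16 per-tensor output, fp32 accumulation inside the kernel.
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    M, K = a.shape
    Kb, N = b.shape
    assert K == Kb
    # Limitation from the summary: the (K, N) operand must be a multiple of 16 for TMA.
    assert K % 16 == 0 and N % 16 == 0, "(K, N) must be multiples of 16 for TMA"

    c = torch.empty((M, N), device=a.device, dtype=torch.float16)

    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128  # placeholder tile sizes
    a_desc = TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K])
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_K, BLOCK_N])
    c_desc = TensorDescriptor.from_tensor(c, [BLOCK_M, BLOCK_N])

    # Persistent launch: at most one program per SM.
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count
    num_tiles = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
    grid = (min(NUM_SMS, num_tiles),)

    fp8_persistent_tma_matmul[grid](
        a_desc, b_desc, c_desc, M, N, K,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=NUM_SMS,
        num_warps=4,   # >= 4 warps so TMA has its 128-thread group
        num_stages=3,
    )
    return c
```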