1- # Float8 MoE Training
1+ # Low precision MoE Training
22
3- This prototype feature provides a way to use float8 rowwise training on MoE layers.
3+ This prototype provides:
44
5- Below is a simple runnable example of how to use this feature, using the MoE layer
6- from the [torchtitan](https://github.com/pytorch/torchtitan) Llama4 implementation for demonstration.
5+ 1. Quantized building block for low precision MoE training: `_scaled_grouped_mm`. It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in the original precision. See the runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
6+     - Using MXFP8 on a B200 GPU, this provides:
7+         - ~1.4x-1.8x speedups over bfloat16 `torch._grouped_mm` for Llama4 17b 16e shapes (depending on the `M` dimension, i.e., batch_size * seq_len)
8+         - ~1.15x-1.3x speedups over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes (depending on the `M` dimension, i.e., batch_size * seq_len)
79
810
11+ 2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration of torchao's dynamically quantized `_scaled_grouped_mm`: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding this flag to your training command: `--model.converters="quantize.grouped_mm.mx" [--quantize.grouped_mm.mx.fqns="experts"]`
12+
13+ 3. `quantize_(...)` API support for model conversion: this swaps all `torch._grouped_mm` ops in your model definition to use torchao's `_scaled_grouped_mm` under the hood (see the [example](#model-conversion-api-example-end-to-end-training) below).
14+
15+
16+ ## Table of Contents
17+
18+ - [Examples](#examples)
19+ - [Performance Benchmarks](#performance-benchmarks-mxfp8)
20+ - [System Requirements](#system-requirements)
21+ - [Implementation Details for Developers](#implementation-details-for-developers)
22+ - [Limitations](#limitations)
23+
24+ ## Examples
25+ #### torchao_scaled_grouped_mm example: forward + backward pass
26+ ```python
27+ import torch
28+ from torch.nn import functional as F
29+ from torchao.prototype.moe_training import (
30+     _scaled_grouped_mm as torchao_scaled_grouped_mm
31+ )
32+ from torchao.prototype.moe_training.conversion_utils import MoEScalingType
33+ from torchao.prototype.moe_training.utils import generate_jagged_offs
34+
35+ num_groups, total_M, N, K = 8, 131072, 8192, 5120
36+
37+ # A = input activations (total_M, K); B = expert weights (num_groups, N, K)
38+ A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
39+ B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
40+
41+ # Token group offsets computed by the router in an actual MoE layer
42+ offs = generate_jagged_offs(num_groups, total_M, device="cuda")
43+
44+ # Forward and backward example
45+ out = torchao_scaled_grouped_mm(
46+     A,
47+     B.transpose(-2, -1),
48+     offs=offs,
49+     scaling_type=MoEScalingType.MXFP8,
50+ )
51+
52+ # (Fake labels for demonstration purposes)
53+ labels = torch.ones_like(out)
54+ loss = F.mse_loss(out, labels)
55+ loss.backward()
56+ ```
57+
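For reference, the `offs` argument produced by `generate_jagged_offs` above is a 1D int32 tensor of cumulative per-expert end indices along the `total_M` dimension. Below is a minimal sketch of constructing equivalent offsets by hand; the group sizes are made up for illustration, and this assumes the usual `torch._grouped_mm` offsets convention:

```python
import torch

# Illustrative only: 4 experts receive 3, 1, 2, and 2 tokens out of total_M = 8 routed tokens.
group_sizes = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda")

# Each entry is the cumulative end index of that expert's token group along dim 0 of A,
# so the last entry equals total_M: tensor([3, 4, 6, 8], dtype=torch.int32)
offs = torch.cumsum(group_sizes, dim=0).to(torch.int32)
```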
58+ #### Model conversion API example: end-to-end training
959``` python
1060import torch
1161from torch import nn
1262from torch.nn import functional as F
1363
14- # this feature requires CUDA and SM89 +
15- assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8 , 9 )
64+ # this feature requires CUDA 12.8+ and SM100+
65+ assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)
1666
1767from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
1868from torchao.quantization.quant_api import quantize_
1969
2070# this example uses torchtitan llama4 MoE, see
71+ # installation instructions at https://github.com/pytorch/torchtitan
2172try :
22- from torchtitan.experiments.llama4.model.args import TransformerModelArgs
23- from torchtitan.experiments.llama4.model.moe import MoE
24- except ImportError as e:
25- raise ImportError (
26- " torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
27- ) from e
73+     from torchtitan.distributed.expert_parallel import (
74+         set_token_group_alignment_size_m,
75+     )
76+     from torchtitan.models.moe import MoE, MoEArgs
77+ except ImportError as e:
78+     raise ImportError(
79+         "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
80+     ) from e
2881
2982
3083# initialize model
3184device = torch.device("cuda")
32- model_args = TransformerModelArgs(
33-     moe_enabled=True,
85+ moe_args = MoEArgs(
3486    num_experts=8,
35-     dim=256,
3687)
37- model = MoE(model_args).to(torch.bfloat16).to(device)
88+ dim, hidden_dim = 5120, 8192
89+ model = MoE(moe_args, dim, hidden_dim).to(torch.bfloat16).to(device)
3890init_std = 0.02
3991model.init_weights(init_std, device)
4092
@@ -48,15 +100,18 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
48100 return True
49101 return False
50102
103+ # Token group alignment size must be 32 for MXFP8 training
104+ # (16 is sufficient for fp8 rowwise)
105+ set_token_group_alignment_size_m(32)
51106
52107# quantize the model
53108config = MoETrainingConfig()
54109quantize_(model, config = config, filter_fn = moe_module_filter_fn)
55110
56111# training loop
57112optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3 )
113+ batch, seq = 2, 2048  # local batch size, sequence length
58114for step in range (10 ):
59- batch, seq, dim = 8 , 2048 , 256
60115 x = torch.randn(
61116 batch, seq, dim, dtype = torch.bfloat16, requires_grad = True , device = device
62117 )
@@ -75,11 +130,97 @@ for step in range(10):
75130
76131```
77132
78- ## Requirements
79- - torchao nightly build
80- - CUDA compute capability 8.9+ (SM89+)
133+ ## System requirements
134+ - torchao 0.14+
135+ - For MXFP8 MoE training: CUDA 12.8+ and an SM100+ GPU (e.g., B200) are required.
136+ - For FP8 rowwise MoE training: CUDA 12.4+ and an SM89+ GPU are required.
137+
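A quick way to check which recipe a given machine can run, mirroring the capability assert used in the example above (the printed messages are illustrative; the thresholds come from the list above):

```python
import torch

# Report whether this machine meets the GPU requirements listed above.
assert torch.cuda.is_available(), "a CUDA GPU is required"
capability = torch.cuda.get_device_capability()
print(f"compute capability: {capability}, CUDA runtime: {torch.version.cuda}")

if capability >= (10, 0):
    print("meets the SM100+ requirement for MXFP8 MoE training")
elif capability >= (8, 9):
    print("meets the SM89+ requirement for FP8 rowwise MoE training")
else:
    print("does not meet the minimum SM89+ requirement for this prototype")
```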
138+ ## Performance benchmarks: MXFP8
139+
140+
141+ ### Single MoE layer forward + backward pass vs bfloat16 baseline
142+
143+ | Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup |
144+ | --------------| ---------| ------| ------| ---------------| -----------------| ---------|
145+ | Llama4 16e | 131072 | 8192 | 5120 | 275.270 | 192.420 | 1.431x |
146+ | DeepSeekV3 | 131072 | 2048 | 7168 | 92.032 | 80.182 | 1.148x |
147+
148+ To reproduce these benchmarks, on a B200 GPU machine, run the following commands:
149+
150+ Llama4 17b 16e shapes:
151+ ``` bash
152+ CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=5120 --hidden_dim=8192 --local_num_experts=8
153+ ```
154+
155+ DeepSeekV3 671b shapes:
156+ ``` bash
157+ CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=7168 --hidden_dim=2048 --local_num_experts=8
158+ ```
159+
160+ ### Individual bfloat16 torch._grouped_mm op vs torchao_scaled_grouped_mm
161+
162+ MXFP8:
163+
164+ | M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
165+ | ------------------------| -----------------| -------------------| ------------------------|
166+ | (128000, 8192, 5120, 1) | 40463 | 24406 | 1.658x |
167+ | (128000, 8192, 5120, 2) | 35494.5 | 24705.1 | 1.437x |
168+ | (128000, 8192, 5120, 4) | 38879.3 | 24508.5 | 1.586x |
169+ | (128000, 8192, 5120, 8) | 35714.6 | 25937.6 | 1.377x |
170+ | (128000, 1536, 5120, 1) | 6353.06 | 7401.54 | 0.858x |
171+ | (128000, 1536, 5120, 2) | 6511.65 | 6729.33 | 0.968x |
172+ | (128000, 1536, 5120, 4) | 6455.2 | 6626.5 | 0.974x |
173+ | (128000, 1536, 5120, 8) | 7716.13 | 6516.74 | 1.184x |
174+ | (128000, 2048, 7168, 1) | 11758 | 11255.7 | 1.045x |
175+ | (128000, 2048, 7168, 2) | 15012.9 | 9917.9 | 1.514x |
176+ | (128000, 2048, 7168, 4) | 14904.2 | 10493.8 | 1.42x |
177+ | (128000, 2048, 7168, 8) | 13178 | 9638.38 | 1.367x |
178+
179+
180+ To reproduce this benchmark, on a B200 GPU machine, run the following command:
181+ - `python benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py --compile`
182+ - torchao version: `0.14.0+gitc7b8e13da`
183+ - torch version: `2.10.0a0+gitf6de195`
184+
185+
186+ ### End-to-end training: Llama4 16e MoE layer vs bfloat16 baseline with TorchTitan
187+ - Single node benchmarks with 4xB200
188+ - Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8
189+ - Reduced num layers from 48 -> 2 to avoid OOM in single node setting
190+ - TorchTitan debug model config (same as Llama4 17bx16e, but with 2 layers):
191+
192+
193+ | Configuration | Throughput (Median Tokens/s) | Max Memory (GiB) | Speedup over bf16 |
194+ |:----------------------------------|-----------------------------:|-----------------:|:------------------|
195+ | bf16 baseline | 49381.0 | 145.55 | - |
196+ | MXFP8 for Linears only | 52038.0 | 146.62 | 1.053x |
197+ | MXFP8 for Grouped GEMMs only | 69350.0 | 144.71 | 1.404x |
198+ | MXFP8 for Linears + Grouped GEMMs | 70747.0 | 145.32 | 1.433x |
199+
200+ #### Commands to reproduce these benchmarks:
201+
202+ bfloat16 baseline:
203+ ```
204+ rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion" ./llama4.sh
205+ ```
206+
207+ MXFP8 dense only:
208+ ```
209+ rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.linear.mx"" ./llama4.sh
210+ ```
211+
212+ MXFP8 MoE only:
213+ ```
214+ rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.grouped_mm.mx"" ./llama4.sh
215+ ```
216+
217+ MXFP8 MoE + Dense:
218+ ```
219+ rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=50 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.grouped_mm.mx,quantize.linear.mx"" ./llama4.sh
220+ ```
221+
81222
82- ## Modeling requirements
223+ ## Implementation details for developers
83224This prototype is specifically designed to be used on MoE models using
84225`torch._grouped_mm` to implement expert computation in token-choice routing,
85226where expert weights are implemented as 3D nn.Parameters with `num_experts` as
@@ -97,5 +238,4 @@ operands in both the forward and backward pass.
97238For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
98239
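To make this concrete, below is a minimal sketch of the modeling convention the conversion targets. The `GroupedExperts` module and its names are illustrative (not part of torchao or torchtitan), but the 3D `nn.Parameter` layout and the `torch._grouped_mm` call are the pattern that gets swapped:

```python
import torch
from torch import nn


class GroupedExperts(nn.Module):
    """Illustrative token-choice expert computation using a single 3D expert weight."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        # num_experts is the leading dim of the 3D expert weight, as required for conversion.
        self.w1 = nn.Parameter(
            torch.empty(num_experts, dim, hidden_dim, dtype=torch.bfloat16)
        )

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        # x: (total_tokens, dim), permuted so each expert's tokens are contiguous.
        # offs: cumulative end index of each expert's token group along dim 0 of x.
        # After quantize_(...) wraps w1 in ScaledGroupedMMTensor, this same call
        # dispatches to torchao's _scaled_grouped_mm instead of the bf16 grouped GEMM.
        return torch._grouped_mm(x, self.w1, offs=offs)
```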
99240## Limitations
100- - Only tested with eager mode, single GPU training so far.
101- - Composability with parallelisms and ` torch.compile ` are next steps.
241+ - The new CUDA kernel for MXFP8 quantization of the non-transposed expert weights in the backward pass does not support tensor parallelism (TP) yet.