
Commit 838dceb

[moe training] update readme (#3163)
* [moe training] update readme
* add repro commands for titan benchmarks
* clean up readme
1 parent e9c98c0 commit 838dceb

File tree

1 file changed: +163 -23 lines changed


torchao/prototype/moe_training/README.md

Lines changed: 163 additions & 23 deletions
@@ -1,40 +1,92 @@
-# Float8 MoE Training
+# Low precision MoE Training

-This prototype feature provides a way to use float8 rowwise training on MoE layers.
+This prototype provides:

-Below is a simple runnable example of how to use this feature, using the MoE layer
-from the [torchtitan](https://github.com/pytorch/torchtitan) Llama4 implementation for demonstration.
+1. Quantized building block for low precision MoE training: `_scaled_grouped_mm`. It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in the original precision. See the runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
+   - Using MXFP8 on a B200 GPU, this provides:
+     - ~1.4x - 1.8x speedups over bfloat16 `torch._grouped_mm` for Llama4 17b 16e shapes (depending on the `M` dimension, i.e. batch_size * seq_len)
+     - ~1.15x - 1.3x speedups over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes (depending on the `M` dimension, i.e. batch_size * seq_len)


+2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration of torchao's dynamically quantized `_scaled_grouped_mm`: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding this flag to your training command: `--model.converters="quantize.grouped_mm.mx" [--quantize.grouped_mm.mx.fqns="experts"]`
+
+3. `quantize_(...)` API support for model conversion: this swaps all `torch._grouped_mm` ops in your model definition to use torchao's `_scaled_grouped_mm` under the hood (see the [example](#model-conversion-api-example-end-to-end-training) below).
+
+
+## Table of Contents
+
+- [Examples](#examples)
+- [Performance Benchmarks](#performance-benchmarks-mxfp8)
+- [System Requirements](#system-requirements)
+- [Implementation Details for Developers](#implementation-details-for-developers)
+- [Limitations](#limitations)
+
+## Examples
+#### torchao_scaled_grouped_mm example: forward + backward pass
+```python
+import torch
+from torch.nn import functional as F
+from torchao.prototype.moe_training import (
+    _scaled_grouped_mm as torchao_scaled_grouped_mm
+)
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
+from torchao.prototype.moe_training.utils import generate_jagged_offs
+
+num_groups, total_M, N, K = 8, 131072, 8192, 5120
+
+# A = input activations, B = expert weights
+A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+
+# Token group offsets computed by the router in an actual MoE layer
+offs = generate_jagged_offs(num_groups, total_M, device="cuda")
+
+# Forward and backward example
+out = torchao_scaled_grouped_mm(
+    A,
+    B.transpose(-2, -1),
+    offs=offs,
+    scaling_type=MoEScalingType.MXFP8,
+)
+
+# (Fake labels for demonstration purposes)
+labels = torch.ones_like(out)
+loss = F.mse_loss(out, labels)
+loss.backward()
+```
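For comparison, the high precision grouped GEMM that `torchao_scaled_grouped_mm` replaces takes the same operands; a minimal sketch, reusing `A`, `B`, and `offs` from the example above and assuming a recent PyTorch build that exposes `torch._grouped_mm`:

```python
# Baseline bfloat16 grouped GEMM with no quantization, for output/accuracy comparison.
out_ref = torch._grouped_mm(A, B.transpose(-2, -1), offs=offs)

# Both calls produce outputs of shape [total_M, N] in the original precision.
print(out.shape, out_ref.shape)
```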

+#### Model conversion API example: end-to-end training
```python
import torch
from torch import nn
from torch.nn import functional as F

-# this feature requires CUDA and SM89+
-assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+# this feature requires CUDA 12.8+ and SM100+
+assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# this example uses torchtitan llama4 MoE, see
+# https://github.com/pytorch/torchtitan for installation instructions
try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.distributed.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
+    from torchtitan.models.moe import MoE, MoEArgs
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e


# initialize model
device = torch.device("cuda")
-model_args = TransformerModelArgs(
-    moe_enabled=True,
+moe_args = MoEArgs(
    num_experts=8,
-    dim=256,
)
-model = MoE(model_args).to(torch.bfloat16).to(device)
+dim, hidden_dim = 5120, 8192
+model = MoE(moe_args, dim, hidden_dim).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)

@@ -48,15 +100,18 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        return True
    return False

+# Token group alignment size must be 32 for MXFP8 training (16 is sufficient otherwise)
+set_token_group_alignment_size_m(32)

# quantize the model
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+batch, seq = 2, 2048
for step in range(10):
-    batch, seq, dim = 8, 2048, 256
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
@@ -75,11 +130,97 @@ for step in range(10):

```
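After `quantize_(...)` has run, it can be helpful to confirm which parameters the filter picked up. A minimal, illustrative check (the `"experts"` substring mirrors the `--quantize.grouped_mm.mx.fqns="experts"` example above and is an assumption about your model's FQNs; the TorchTitan integration exposes a similar check via `--model.print_after_conversion`):

```python
# Print the converted module tree, then the name and shape of each expert weight.
print(model)
for name, param in model.named_parameters():
    if "experts" in name:
        print(name, tuple(param.shape))
```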

-## Requirements
-- torchao nightly build
-- CUDA compute capability 8.9+ (SM89+)
+## System requirements
+- torchao 0.14+
+- For MXFP8 MoE training, CUDA 12.8+ and SM100+ GPU arch are required.
+- For FP8 rowwise MoE training, CUDA 12.4+ and SM89+ GPU arch are required.
+
+## Performance benchmarks: MXFP8
+
+### Single MoE layer forward + backward pass vs bfloat16 baseline
+
+| Model      | total_M | N    | K    | bf16 time (ms) | mxfp8 time (ms) | speedup |
+|------------|---------|------|------|----------------|-----------------|---------|
+| Llama4 16e | 131072  | 8192 | 5120 | 275.270        | 192.420         | 1.431x  |
+| DeepSeekV3 | 131072  | 2048 | 7168 | 92.032         | 80.182          | 1.148x  |
+
+To reproduce these benchmarks on a B200 GPU machine, run the following commands:
+
+Llama4 17b 16e shapes:
+```bash
+CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=5120 --hidden_dim=8192 --local_num_experts=8
+```
+
+DeepSeekV3 671b shapes:
+```bash
+CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=7168 --hidden_dim=2048 --local_num_experts=8
+```
+
+### Individual bfloat16 torch._grouped_mm op vs torchao_scaled_grouped_mm
+
+MXFP8:
+
+| M,N,K,G                  | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
+|--------------------------|-----------------|-------------------|------------------------|
+| (128000, 8192, 5120, 1)  | 40463           | 24406             | 1.658x                 |
+| (128000, 8192, 5120, 2)  | 35494.5         | 24705.1           | 1.437x                 |
+| (128000, 8192, 5120, 4)  | 38879.3         | 24508.5           | 1.586x                 |
+| (128000, 8192, 5120, 8)  | 35714.6         | 25937.6           | 1.377x                 |
+| (128000, 1536, 5120, 1)  | 6353.06         | 7401.54           | 0.858x                 |
+| (128000, 1536, 5120, 2)  | 6511.65         | 6729.33           | 0.968x                 |
+| (128000, 1536, 5120, 4)  | 6455.2          | 6626.5            | 0.974x                 |
+| (128000, 1536, 5120, 8)  | 7716.13         | 6516.74           | 1.184x                 |
+| (128000, 2048, 7168, 1)  | 11758           | 11255.7           | 1.045x                 |
+| (128000, 2048, 7168, 2)  | 15012.9         | 9917.9            | 1.514x                 |
+| (128000, 2048, 7168, 4)  | 14904.2         | 10493.8           | 1.42x                  |
+| (128000, 2048, 7168, 8)  | 13178           | 9638.38           | 1.367x                 |
+
+To reproduce this benchmark on a B200 GPU machine, run the following command:
+- `python benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py --compile`
+
+Versions used:
+- torchao: `0.14.0+gitc7b8e13da`
+- torch: `2.10.0a0+gitf6de195`
+
+### End-to-end training: Llama4 16e MoE layer vs bfloat16 baseline with TorchTitan
+- Single node benchmarks with 4xB200
+- Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8
+- Reduced num layers from 48 -> 2 to avoid OOM in the single node setting
+- TorchTitan debug model config (same as Llama4 17bx16e, but with 2 layers):
+
+| Configuration                     | Throughput (Median Tokens/s) | Max Memory (GiB) | Speedup over bf16 |
+|:----------------------------------|-----------------------------:|------------------|-------------------|
+| bf16 baseline                     | 49381.0                      | 145.55           | -                 |
+| MXFP8 for Linears only            | 52038.0                      | 146.62           | 1.053x            |
+| MXFP8 for Grouped GEMMs only      | 69350.0                      | 144.71           | 1.404x            |
+| MXFP8 for Linears + Grouped GEMMs | 70747.0                      | 145.32           | 1.433x            |
+
+#### Commands to reproduce these benchmarks:
+
+bfloat16 baseline:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion" ./llama4.sh
+```
+
+MXFP8 dense only:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.linear.mx"" ./llama4.sh
+```
+
+MXFP8 MoE only:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=200 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.grouped_mm.mx"" ./llama4.sh
+```
+
+MXFP8 MoE + Dense:
+```
+rm -rf /tmp/torchinductor_${USER}; CUDA_VISIBLE_DEVICES="4,5,6,7" TORCHTITAN_ROOT=/home/${USER}/torchtitan NGPU=4 EXTRA_ARGS="--metrics.log_freq=10 --training.steps=50 --parallelism.data_parallel_shard_degree=4 --parallelism.expert_parallel_degree=4 --parallelism.tensor_parallel_degree=1 --compile.enable --training.seq_len=8192 --activation_checkpoint.mode=none --model.print_after_conversion --model.converters="quantize.grouped_mm.mx,quantize.linear.mx"" ./llama4.sh
+```

-## Modeling requirements
+## Implementation details for developers
This prototype is specifically designed to be used on MoE models using
`torch._grouped_mm` to implement expert computation in token-choice routing,
where expert weights are implemented as 3D nn.Parameters with `num_experts` as
@@ -97,5 +238,4 @@ operands in both the forward and backward pass.
For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
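To make the expected model structure concrete, here is a minimal sketch of such a module (illustrative only: the class name, parameter layout, and `offs` convention are assumptions for this sketch, not actual torchao or torchtitan APIs):

```python
import torch
from torch import nn


class GroupedExperts(nn.Module):
    """Hypothetical token-choice expert module in the layout this prototype targets."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        # Single 3D expert weight with num_experts as the leading dimension.
        self.w = nn.Parameter(
            torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16)
        )

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        # x: [total_tokens, dim] with each expert's tokens stored contiguously;
        # offs: per-expert end offsets along the token dimension (as produced by
        # the router, or by generate_jagged_offs in the example above).
        # quantize_(...) swaps this torch._grouped_mm call for torchao's
        # _scaled_grouped_mm when the module passes the filter_fn.
        return torch._grouped_mm(x, self.w, offs=offs)
```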

## Limitations
-- Only tested with eager mode, single GPU training so far.
-- Composability with parallelisms and `torch.compile` are next steps.
+- The new CUDA kernel for MXFP8 quantization of the non-transposed expert weights in the backward pass does not support TP yet.
