Commit 3d509d9

[mxfp8 moe training] mxfp8 all_to_all_vdev_2d kernel
1 parent d2fae7a commit 3d509d9

6 files changed: +825 −1 lines changed

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
import pytest

pytest.importorskip("triton")

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.distributed._functional_collectives import (
    all_to_all_single_autograd,
)
from torch.nn import functional as F
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
)

from torchao.float8.float8_utils import (
    compute_error,
)
from torchao.prototype.moe_training.kernels.mxfp8.comms import (
    mxfp8_on_device_all_to_all_v,
)


@instantiate_parametrized_tests
class TritonAllReduceTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    def _init_device(self):
        symm_mem.set_backend("NVSHMEM")

    def test_a2a_fwd_bwd(self):
        self._init_process()
        try:
            torch.manual_seed(42 + self.rank)
            self._init_device()

            group_name = dist.group.WORLD.group_name
            symm_mem.enable_symm_mem_for_group(group_name)

            tokens_per_ep_rank = 8192
            dim = 2048
            input_tensor = torch.randn(
                tokens_per_ep_rank,
                dim,
                device=self.device,
                dtype=torch.float32,
                requires_grad=True,
            )
            ref_input_tensor = input_tensor.detach().clone().requires_grad_(True)

            # Generate random input splits that sum to tokens_per_ep_rank
            num_splits = self.world_size
            input_splits = generate_split_sizes(
                num_splits, tokens_per_ep_rank, self.device
            )

            # Max output tokens per rank is the worst case, where one rank receives all tokens
            max_output_tokens_per_rank = tokens_per_ep_rank * self.world_size

            # Test forward
            output, output_splits = mxfp8_on_device_all_to_all_v(
                input_tensor,
                input_splits,
                max_output_tokens_per_rank,
                group_name,
            )

            # Reference torch.all_to_all_single to compare against
            output_splits_ref = torch.empty_like(output_splits)

            # Compute output splits from input splits
            dist.all_to_all_single(output_splits_ref, input_splits)

            # Pre-allocate output buffer for reference a2a
            total_tokens_on_rank_after_a2a = output_splits_ref.sum()
            ref_output = torch.empty(
                total_tokens_on_rank_after_a2a,
                dim,
                device=self.device,
                dtype=torch.float32,
            )

            # Do the actual all_to_all_single
            ref_output = all_to_all_single_autograd(
                ref_input_tensor,
                output_splits_ref.tolist(),
                input_splits.tolist(),
                dist.group.WORLD,
            )

            # Compare output
            assert torch.equal(output_splits, output_splits_ref), (
                "output_splits mismatch"
            )
            out_no_padding = output[:total_tokens_on_rank_after_a2a]
            sqnr = compute_error(ref_output, out_no_padding)
            min_sqnr = 30.0
            assert sqnr > min_sqnr, f"sqnr={sqnr} is less than min_sqnr={min_sqnr}"

            # Test backwards
            labels = torch.ones_like(out_no_padding)
            loss = F.mse_loss(out_no_padding, labels)
            ref_loss = F.mse_loss(ref_output, labels)
            loss.backward()
            ref_loss.backward()

            # Compare grads
            grad_sqnr = compute_error(ref_input_tensor.grad, input_tensor.grad)
            min_grad_sqnr = 28.0
            assert grad_sqnr > min_grad_sqnr, (
                f"grad_sqnr={grad_sqnr} is less than min_grad_sqnr={min_grad_sqnr}"
            )

        finally:
            dist.destroy_process_group()


def generate_split_sizes(K: int, N: int, device: str = "cpu") -> torch.Tensor:
    """
    Generates a tensor of K random non-negative integers that sum to N.
    """
    if K <= 0:
        raise ValueError("K must be a positive integer.")
    if N < 0:
        raise ValueError("N must be a non-negative integer.")

    if K == 1:
        return torch.tensor([N], dtype=torch.long, device=device)

    # Generate K-1 random "dividers" in the range [0, N].
    dividers = torch.randint(0, N + 1, (K - 1,), device=device)

    # Add 0 and N to the set of dividers to form the boundaries.
    boundaries = torch.cat(
        [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)]
    )

    # Sort the boundaries to ensure they are in order.
    sorted_boundaries = torch.sort(boundaries).values

    # The K integers are the differences between consecutive boundaries (they sum to N).
    result = sorted_boundaries[1:] - sorted_boundaries[:-1]

    return result.to(dtype=torch.int64)


if __name__ == "__main__":
    run_tests()
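
For orientation, here is a minimal single-process sketch (not part of the commit) of the split bookkeeping the test exercises: in an all-to-all-v, rank r's output_splits are column r of the matrix whose rows are each rank's input_splits, which is why the test exchanges splits with dist.all_to_all_single first, over-allocates the output for the worst case where one rank receives every token, and then slices off the padding. The split values below are made up for illustration; in the test they come from generate_split_sizes.

import torch

world_size = 2
tokens_per_ep_rank = 8192

# Row s holds rank s's input_splits: how many of its tokens go to each peer rank.
# These numbers are illustrative; each row sums to tokens_per_ep_rank.
input_splits_per_rank = torch.tensor(
    [[5000, 3192],   # rank 0 sends 5000 tokens to rank 0, 3192 to rank 1
     [1000, 7192]],  # rank 1 sends 1000 tokens to rank 0, 7192 to rank 1
    dtype=torch.int64,
)

# dist.all_to_all_single(output_splits, input_splits) hands each rank one column
# of this matrix: rank r receives input_splits_per_rank[:, r] as its output_splits.
output_splits_per_rank = input_splits_per_rank.T

# Worst-case sizing: one rank could receive every token, so the output buffer is
# allocated with tokens_per_ep_rank * world_size rows; only the first
# output_splits.sum() rows are valid and the rest is padding to slice off.
max_output_tokens_per_rank = tokens_per_ep_rank * world_size
valid_rows_rank0 = output_splits_per_rank[0].sum()   # 5000 + 1000 = 6000
assert valid_rows_rank0 <= max_output_tokens_per_rank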
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    compute_blocked_scale_offsets_for_K_groups,  # noqa: F401
    compute_blocked_scale_offsets_for_M_groups,  # noqa: F401
    mxfp8_quantize_cuda_3d,  # noqa: F401
    torch_to_blocked_2d_K_groups,  # noqa: F401
    torch_to_blocked_2d_M_groups,  # noqa: F401
    torch_to_blocked_per_group_3d,  # noqa: F401
    triton_mx_block_rearrange_2d_K_groups,  # noqa: F401
    triton_mx_block_rearrange_2d_M_groups,  # noqa: F401
    triton_mx_block_rearrange_per_group_3d,  # noqa: F401
)
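
The comms kernel presumably quantizes tokens to mxfp8 for the dispatch and dequantizes on the receiving side, which is why the test above checks SQNR against an unquantized reference (30 dB forward, 28 dB for gradients) rather than exact equality. As a rough intuition for those thresholds, here is a pure-PyTorch roundtrip sketch, not the torchao implementation, assuming the usual MX layout (blocks of 32 elements along the last dimension, float8_e4m3 values, a shared power-of-two scale per block); the SQNR helper mirrors what compute_error reports, 20 * log10(||ref|| / ||ref - approx||).

import torch

BLOCK = 32        # assumed mxfp8 block size
E4M3_MAX = 448.0  # largest finite magnitude of torch.float8_e4m3fn

def mxfp8_roundtrip_sketch(x: torch.Tensor) -> torch.Tensor:
    """Quantize each block of 32 values to float8_e4m3 with a shared
    power-of-two scale, then dequantize back to float32."""
    blocks = x.reshape(-1, BLOCK)
    amax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    # Power-of-two scale chosen so the block max lands near the top of the e4m3 range.
    scale = torch.exp2(torch.floor(torch.log2(E4M3_MAX / amax)))
    q = (blocks * scale).to(torch.float8_e4m3fn)            # quantize
    return (q.to(torch.float32) / scale).reshape(x.shape)   # dequantize

def sqnr_db(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Same quantity compute_error returns: 20 * log10(||ref|| / ||ref - approx||).
    return 20 * torch.log10(torch.norm(ref) / torch.norm(ref - approx))

x = torch.randn(8192, 2048)
# For Gaussian data a single e4m3 roundtrip typically lands around 30 dB,
# the regime the test's thresholds target.
print(sqnr_db(x, mxfp8_roundtrip_sketch(x)))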
