
@danielvegamyhre (Contributor) commented on Sep 23, 2025

Summary

  • Trace analysis of Llama4 mxfp8 training on B200 (FSDP=4, EP=4, seq_len=8192, local_batch_size=8) shows that exposed a2a comms constitute a huge % of total runtime (see the trace screenshot below).
  • We can improve this by doing dynamic quant on the inputs -> a2a on the mxfp8 data and scales -> dequantize the outputs, which should reduce bytes sent over the network by ~45% (see the sketch below the screenshot). There will be some overhead from the dynamic quant/dequant, but it is on the order of microseconds, while the a2a here is 200+ms, so it should be a good trade-off and a solid net perf benefit.
  • My main concern is how I'm allocating the symmetric memory buffers and the necessary rendezvous overhead: the idea is for this to happen only once, on the first forward pass, and be re-used afterwards, but it remains to be seen whether this works as I expect in practice.
[Screenshot (Sep 23, 2025): profiler trace showing exposed a2a comms]
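To make the intended data flow concrete, here is a minimal sketch of the quantize -> a2a -> dequantize pattern using plain torch.distributed collectives; the `to_mxfp8`/`from_mxfp8` helpers are hypothetical stand-ins, and the PR fuses all three steps into a single Triton kernel over symmetric memory instead. On the byte math: bf16 costs 2 bytes/element, while mxfp8 costs 1 byte/element plus 1 scale byte per 32-element block, i.e. ~1.03 bytes/element, roughly the ~45% reduction cited above.

```python
# Sketch only: illustrates the quantize -> all_to_all -> dequantize data flow.
# `to_mxfp8` / `from_mxfp8` are hypothetical helpers, not this PR's API.
import torch
import torch.distributed as dist

BLOCK_SIZE = 32  # one e8m0 scale per 32 elements along the last dim

def mxfp8_all_to_all_single(x, in_splits, out_splits):
    # 1) Dynamic quantization: fp8 payload + per-block scales.
    data, scales = to_mxfp8(x, block_size=BLOCK_SIZE)
    data = data.view(torch.uint8)  # ship the fp8 bits as uint8 over NCCL

    # 2) Exchange payload and scales. Splits are along dim 0, so the same
    #    split sizes apply to both tensors (scales just have K // 32 columns).
    data_out = data.new_empty(sum(out_splits), *data.shape[1:])
    scales_out = scales.new_empty(sum(out_splits), *scales.shape[1:])
    dist.all_to_all_single(data_out, data, out_splits, in_splits)
    dist.all_to_all_single(scales_out, scales, out_splits, in_splits)

    # 3) Dequantize back to the original dtype on the receiving side.
    return from_mxfp8(data_out.view(torch.float8_e4m3fn), scales_out, out_dtype=x.dtype)
```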

Changes

Using the bf16 all_to_all_v kernel from torchtitan as a starting point, this PR implements:

  • An `_mxfp8_on_device_all_to_all_v` kernel using Triton + Symmetric Memory for inter-device comms
  • An autograd function wrapper to make it differentiable (see the sketch after this list)
  • Unit tests verifying correctness of both the forward and backward passes, by comparing against the standard fp32 `all_to_all_single_autograd` and validating that SQNR is high (>28)
  • A refactor separating the quantization kernels and the comms kernels into their own files/modules
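For reference, a minimal sketch of what the autograd wrapper looks like; the entry point `mxfp8_on_device_all_to_all_v` and its signature are illustrative, not necessarily the PR's exact API. The key property is that the backward of an a2a is another a2a with the input/output splits swapped.

```python
import torch

class _MXFP8AllToAllV(torch.autograd.Function):
    """Sketch of a differentiable wrapper: the backward of an all_to_all_v
    is an all_to_all_v of the gradients with the splits swapped."""

    @staticmethod
    def forward(ctx, x, in_splits, out_splits, group):
        ctx.in_splits, ctx.out_splits, ctx.group = in_splits, out_splits, group
        # Hypothetical kernel entry point: quantize -> a2a -> dequantize.
        return mxfp8_on_device_all_to_all_v(x, in_splits, out_splits, group)

    @staticmethod
    def backward(ctx, grad_out):
        # Route gradients back along the reverse communication pattern.
        grad_in = mxfp8_on_device_all_to_all_v(
            grad_out, ctx.out_splits, ctx.in_splits, ctx.group
        )
        return grad_in, None, None, None
```

The SQNR check in the tests then amounts to something like `10 * torch.log10(ref.pow(2).mean() / (ref - out).pow(2).mean()) > 28`.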

Benchmarks

input_shape        num_splits    bf16_us    mxfp8_us
---------------  ------------  ---------  ----------
(8, 8192, 5120)             8     4032.3     1256.14

(~3.2x speedup over the bf16 baseline for this shape)
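For context, timings like these are typically collected with CUDA events around the collective; a sketch of such a harness (not the PR's actual benchmark code):

```python
import torch

def benchmark_us(fn, warmup: int = 5, iters: int = 20) -> float:
    """Median wall time of fn() in microseconds, via CUDA events."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end) * 1e3)  # ms -> us
    return sorted(times)[len(times) // 2]
```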

Next steps

  • Refactor to use NVSHMEM APIs directly for inter-node a2a
  • torchtitan integration, by extending the configurable EP a2a pattern to use this kernel
  • Extend to `all_to_all_v_2d` to fuse in the local token shuffle, so the shuffle runs in low precision (moving fewer bytes around, hence faster) and we dequantize only after the token shuffle
  • Look into loop unrolling, which the bf16 implementation uses

@pytorch-bot bot commented on Sep 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3048

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

⏳ No Failures, 3 Pending

As of commit 0f381e3 with merge base d2fae7a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@danielvegamyhre added the topic: not user facing label Sep 23, 2025
@meta-cla bot added the CLA Signed label Sep 23, 2025
@danielvegamyhre added the mx and moe labels and removed the CLA Signed label Sep 23, 2025
@danielvegamyhre marked this pull request as draft September 23, 2025 17:47
@meta-cla bot added the CLA Signed label Sep 23, 2025
@danielvegamyhre force-pushed the a2a branch 7 times, most recently from cb2a361 to 12bb6ae September 24, 2025 21:59
@danielvegamyhre changed the title from "[WIP] [mxfp8 moe training] mxfp8_on_device_all_to_all_v kernel" to "[mxfp8 moe training] mxfp8_on_device_all_to_all_v kernel" Sep 24, 2025
@danielvegamyhre marked this pull request as ready for review September 24, 2025 22:00
@danielvegamyhre force-pushed the a2a branch 2 times, most recently from 69b8f27 to 4391e16 September 24, 2025 22:44
@kwen2501 (Contributor) left a comment


LGTM! Thanks for quantizing the kernel!

# Cross-rank barrier over the signal pads (relaxed semantics), then sync
# the threads within this block before touching remote buffers.
blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
sync_threads()

# Assign each group of BLOCKS_PER_REMOTE_RANK blocks to one remote rank.
remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
Contributor commented:

Note that this works only when the number of blocks >= the number of ranks.
Otherwise, we'd need to use the few blocks we have to loop over the remote_rank space (see the sketch below).
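Something like this persistent-kernel style loop could cover that case (illustrative Triton sketch, not code from this PR):

```python
# Sketch: grid-stride over the remote_rank space so correctness doesn't
# depend on launching at least world_size blocks.
pid = tl.program_id(0)
num_blocks = tl.num_programs(0)
for remote_rank in range(pid, world_size, num_blocks):
    # ... process the split destined for `remote_rank`, as in the body above ...
    pass
```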

# Device-side arrays of per-rank buffer pointers for the scales and signal pad.
input_scales_ptrs = input_scales_hdl.buffer_ptrs_dev
signal_pad_ptrs = input_hdl.signal_pad_ptrs_dev
# Hidden dim of the output, and the number of scale columns (K // block_size).
dim = output.shape[1]
dim_scaling_groups = input_scales.shape[-1]
Contributor commented:

Curious, how big is dim_scaling_groups in general?

@danielvegamyhre (Contributor, Author) replied:

For an input of shape (M, K), it is the number of columns in the scales tensor, which has shape (M, K // block_size), where block_size is usually 32.

I may rename the variable to something clearer (at minimum I'll clarify this in the docstring).
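For concreteness, a quick shape check using the benchmark's input dims (illustrative numbers):

```python
M, K, block_size = 8 * 8192, 5120, 32
scales_shape = (M, K // block_size)   # one scale column per 32-element block
dim_scaling_groups = K // block_size  # = 160 here
```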

@danielvegamyhre danielvegamyhre merged commit 4013764 into main Sep 26, 2025
18 checks passed