[mxfp8 moe training] mxfp8_on_device_all_to_all_v kernel #3048
LGTM! Thanks for quantizing the kernel!
```python
blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
sync_threads()

remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
```
Note that this only works when the number of blocks is >= the number of ranks. Otherwise, we'd need to use the few blocks we have to loop over the remote_rank space.
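For illustration, a minimal sketch of what such a loop could look like, assuming a grid-stride style walk over the remote-rank space. This is not code from the PR; the kernel name, arguments, and the idea of advancing by `ranks_per_pass` are all assumptions made for the example.

```python
# Hypothetical sketch, not from this PR: if the kernel were launched with fewer
# blocks than world_size * BLOCKS_PER_REMOTE_RANK, each block could stride over
# the remote-rank space instead of owning exactly one remote rank.
import triton
import triton.language as tl


@triton.jit
def _a2a_loop_over_ranks_sketch(world_size, BLOCKS_PER_REMOTE_RANK: tl.constexpr):
    pid = tl.program_id(0)
    # How many remote ranks the whole grid covers per pass; the sketch assumes
    # the grid has at least BLOCKS_PER_REMOTE_RANK blocks, so this is >= 1.
    ranks_per_pass = tl.num_programs(0) // BLOCKS_PER_REMOTE_RANK
    remote_rank = pid // BLOCKS_PER_REMOTE_RANK
    while remote_rank < world_size:
        # ... process the data split exchanged with `remote_rank` here ...
        remote_rank += ranks_per_pass

# Launched as usual with a grid that may be smaller than world_size, e.g.:
# _a2a_loop_over_ranks_sketch[(num_blocks,)](world_size, BLOCKS_PER_REMOTE_RANK=8)
```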
```python
input_scales_ptrs = input_scales_hdl.buffer_ptrs_dev
signal_pad_ptrs = input_hdl.signal_pad_ptrs_dev
dim = output.shape[1]
dim_scaling_groups = input_scales.shape[-1]
```
Curious, how big is dim_scaling_groups in general?
For an input of shape (M, K), it is the number of columns in the scales tensor, which will have shape (M, K // block_size), where block_size is usually 32.
I may rename the variable to something clearer (at minimum, I will clarify this in the docstring).
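For concreteness, a tiny illustrative snippet of the shape relationship described above. The specific M and K values are made up; block_size = 32 is the usual mxfp8 group size mentioned in the comment.

```python
# Illustrative only: relationship between the input shape and the scales shape.
M, K, block_size = 8192, 4096, 32        # example sizes; block_size of 32 per the comment above
input_shape = (M, K)                     # quantized input data
scales_shape = (M, K // block_size)      # one scale per block_size-element group along K
dim_scaling_groups = scales_shape[-1]    # == K // block_size
print(dim_scaling_groups)                # 128
```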
Summary

Changes

Using the bf16 all_to_all_v kernel from torchtitan as a starting point, this PR implements:

- _mxfp8_on_device_all_to_all_v kernel using Triton + Symmetric Memory for inter-device comms
- Test comparing numerics against all_to_all_single_autograd and validating SQNR is high (>28); a rough sketch of such a check is shown after this list.
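The sketch below shows how an SQNR check of this kind could be written; it is not the PR's actual test code. The 28 dB threshold comes from the description above, while the helper name and the commented-out op/tensor names are placeholders.

```python
# Rough sketch of an SQNR check: compare the mxfp8 all-to-all output against a
# bf16 reference and assert the signal-to-quantization-noise ratio is high.
import torch

def sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # SQNR (dB) = 10 * log10(signal power / quantization-noise power)
    noise = ref.float() - test.float()
    return 10 * torch.log10(ref.float().pow(2).mean() / noise.pow(2).mean())

# Hypothetical usage (names are placeholders, not the PR's API):
# ref_out  = all_to_all_single_autograd(x_bf16, out_splits, in_splits, group)
# test_out = mxfp8_all_to_all_v(x_bf16, ...)
# assert sqnr(ref_out, test_out) > 28.0
```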
Benchmarks

Next steps