[Kernels] Add an inductor pass to rewrite and fuse collective communication ops with gemms #9886
This pull request has merge conflicts that must be resolved before it can be merged.
looking forward to this one!
device_group = group.device_group
rank = group.rank_in_group

if use_flux:
Could we maybe use a better abstraction than `if` statements based on `use_flux`?
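One possible reading of this suggestion, as a minimal sketch rather than anything taken from the PR: inspect the flag once, pick a backend object, and let the rewrite code talk only to a shared interface. Every class, function, and op name below is hypothetical and used purely for illustration.

```python
from abc import ABC, abstractmethod


class FusedGemmBackend(ABC):
    """Hypothetical interface (not from the PR): one subclass per backend."""

    name: str

    @abstractmethod
    def fused_op_name(self) -> str:
        """Return a placeholder identifier for the fused op to emit."""


class FluxBackend(FusedGemmBackend):
    name = "flux"

    def fused_op_name(self) -> str:
        return "flux_fused_gemm"  # placeholder, not a real op name


class FallbackBackend(FusedGemmBackend):
    name = "fallback"

    def fused_op_name(self) -> str:
        return "fallback_fused_gemm"  # placeholder, not a real op name


def select_backend(use_flux: bool) -> FusedGemmBackend:
    # The flag is inspected once, here; callers only see the interface.
    return FluxBackend() if use_flux else FallbackBackend()


def describe_rewrite(backend: FusedGemmBackend) -> str:
    # Stand-in for the graph rewrite: no branching on use_flux is needed.
    return f"emit {backend.fused_op_name()} via {backend.name} backend"
```

With this shape, adding a third backend would mean adding a subclass and extending `select_backend`, while the rewrite logic stays untouched.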
fused_node = graph.call_function(fused_gemm_func,
                                 kwargs=kwargs)

graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem,
                                      (fused_node, 0))
residual_node_new = graph.call_function(
    operator.getitem, (fused_node, 1))
my_residual_node_new = graph.call_function(
    operator.getitem, (fused_node, 2))
I think multi-output match has a utility that emits a function and tuple accessors.
res_replacements.append(residual_node_new)
my_res_replacements.append(my_residual_node_new)
Any reason we save all of the residuals instead of just the previous one?
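To make the question concrete, here is a small sketch under an assumption the diff does not confirm: that each rewrite step only ever reads the residual produced by the step immediately before it. If that holds, a single variable carries the same information as the growing lists. The function names and the string stand-ins for graph nodes are hypothetical.

```python
from typing import List, Optional, Tuple


def pair_with_previous_keep_all(steps: List[str]) -> List[Tuple[str, Optional[str]]]:
    """Store every residual, but only ever read the most recent one."""
    residuals: List[str] = []
    pairs = []
    for step in steps:
        prev = residuals[-1] if residuals else None
        pairs.append((step, prev))
        residuals.append(f"residual_of_{step}")  # stand-in for a graph node
    return pairs


def pair_with_previous_keep_last(steps: List[str]) -> List[Tuple[str, Optional[str]]]:
    """Same result, tracking only the previous residual."""
    prev: Optional[str] = None
    pairs = []
    for step in steps:
        pairs.append((step, prev))
        prev = f"residual_of_{step}"  # overwrite instead of appending
    return pairs
```

If any later step did need an older residual, the list-based version would be the right choice, so this equivalence depends entirely on the stated assumption.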
if gemm_1 is None or gemm_2 is None:
    raise ValueError("Missing 'val' in gemm weights meta data")
Wouldn't it be simpler if you just do `meta["val"]`?
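The trade-off behind this comment can be sketched in plain Python, assuming the reviewed code fetched the value with `.get("val")` before the `None` check (the diff context shown here does not include that line). The helper names are hypothetical; `meta` stands in for a node's metadata dictionary.

```python
def get_val_checked(meta: dict):
    # Pattern assumed from the snippet: fetch with .get, then raise by hand.
    val = meta.get("val")
    if val is None:
        raise ValueError("Missing 'val' in gemm weights meta data")
    return val


def get_val_direct(meta: dict):
    # Reviewer's suggestion: index directly and let a missing key
    # surface as a KeyError without an explicit check.
    return meta["val"]
```

Direct indexing is shorter and fails at the exact lookup, while the explicit check gives a more descriptive message; it also treats a stored `None` as missing, which direct indexing does not.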
Signed-off-by: Bill Nell <[email protected]>
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!
Add an inductor pass to rewrite and fuse collective communication ops with gemms
See #9883 for a version that includes the llama hacks.
TODO:
torch._inductor.ir.ExternKernel.__str__
pytorch/pytorch#139501
cc @tlrmchlsmth, @ProExpertProg, @SageMoore, @youkaichao
Requires a special config to run:
Some benchmark results: