Skip to content

Conversation

@youkaichao
Copy link
Member

pytorch's ipc handle format can change, and using pytorch for cuda ipc will suffer from pytorch's change. see #9815 for example.

cc @hanzhi713

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@github-actions
Copy link

github-actions bot commented Nov 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Copy link
Member Author

test failure is because we don't have a100 gpus in our ci queue.

@hanzhi713
Copy link
Contributor

@youkaichao I can confirm this work on my machine.

world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to use broadcast with device=cpu in _gather_ipc_meta but not here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when you use this function, the group argument should be cpu_group passed to custom allreduce object.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see

if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is all_gather fine here but not in _gather_ipc_meta?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, that is because we met some issues with all_gather for tensors directly. here we are using all_gather_object , so it should be fine. see pytorch/pytorch#126032 for the pytorch issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Someone refactored this to return a Tensor
https://github.com/vllm-project/vllm/pull/5047/files#diff-44d9d733ee604800cbce9858a9201db1044aeff2c940fa4a0521d0c9b6541b3eL137

A better way should be returning a string, if torch bindings doesn't support int8 directly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah string should be fine. #5047 aims to get rid of pybind11 so that we can release python version agnostic wheels.

@youkaichao youkaichao merged commit 4be3a45 into vllm-project:main Nov 6, 2024
22 of 26 checks passed
@youkaichao youkaichao deleted the create_shared_buffer branch November 6, 2024 06:35
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants