
Commit 1b69ea4

youkaichao authored and siddvenk committed
[core][distributed] fix custom allreduce in pytorch 2.5 (vllm-project#9815)
Signed-off-by: youkaichao <[email protected]>
1 parent d5ffc40 commit 1b69ea4

File tree

1 file changed: +13 -1 lines changed


vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 13 additions & 1 deletion
@@ -191,8 +191,20 @@ def capture(self):
 
     def _get_ipc_meta(self, inp: torch.Tensor):
         data = inp.untyped_storage()._share_cuda_()
+        handle = data[1]
+        # https://github.com/pytorch/pytorch/pull/130890 changes
+        # the binary format of the ipc handle
+        # it starts from pytorch 2.5
+        if len(handle) > 64:
+            assert len(handle) == 66
+            # only support SHAREABLE_HANDLE_VERSION = 1
+            assert int(handle[0]) == 1
+            # only support SHAREABLE_CUDA_MALLOC = 'c'
+            assert handle[1] == ord("c")
+            handle = handle[2:]
+            # TODO: support expandable segment
         shard_data = (
-            data[1],  # ipc handle to base ptr
+            handle,  # ipc handle to base ptr
             data[3],  # offset of base ptr
         )
         return self._gather_ipc_meta(shard_data)
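For context beyond the diff itself, here is a minimal standalone sketch of the handle normalization this commit performs. The helper name normalize_ipc_handle is illustrative, not part of vllm: per pytorch/pytorch#130890, starting with pytorch 2.5, untyped_storage()._share_cuda_() returns a 66-byte handle made of a version byte, a handle-type byte, and the raw 64-byte CUDA IPC handle, while earlier pytorch versions return the bare 64-byte handle.

def normalize_ipc_handle(handle: bytes) -> bytes:
    # Hypothetical helper: return the raw 64-byte cudaIpcMemHandle_t
    # regardless of pytorch version.
    if len(handle) > 64:
        # pytorch >= 2.5 layout: [version byte][handle-type byte][64-byte handle]
        assert len(handle) == 66
        # only SHAREABLE_HANDLE_VERSION = 1 is handled
        assert int(handle[0]) == 1
        # only SHAREABLE_CUDA_MALLOC ('c') is handled; expandable segments
        # use a different tag and are not supported here
        assert handle[1] == ord("c")
        return handle[2:]
    # pytorch < 2.5: the value is already the bare 64-byte handle
    return handle

Asserting on the version and handle-type bytes, rather than silently slicing, means an unexpected format (for example, an expandable-segment allocation) fails loudly instead of producing a corrupt handle.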
