@@ -6,10 +6,11 @@
 import torch.distributed
 
 from vllm.distributed.communication_op import (  # noqa
-    graph_capture, tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             get_world_group, graph_capture,
                                              init_distributed_environment)
 from vllm.utils import update_environment_variables
 
@@ -53,7 +54,8 @@ def wrapped_fn(env):
 
 @worker_fn_wrapper
 def worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
@@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        pynccl_comm = PyNcclCommunicator()
+        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                         device=get_world_group().device)
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
@@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():
 
 @worker_fn_wrapper
 def send_recv_worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     if pynccl_comm.rank == 0:
         tensor = torch.ones(16, 1024, 1024,
                             dtype=torch.float32).cuda(pynccl_comm.rank)
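
Every hunk above makes the same substitution: `PyNcclCommunicator` is now constructed explicitly from the world group's CPU process group and its device, instead of discovering both internally. A minimal sketch of the new call pattern is below; it assumes `init_distributed_environment()` has already run for the process (the tests handle this via `worker_fn_wrapper`), and the `make_pynccl_comm` helper is hypothetical, introduced here only for illustration.

```python
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group


def make_pynccl_comm() -> PyNcclCommunicator:
    # Hypothetical helper: assumes init_distributed_environment() has
    # already set up the default process group for this worker.
    world_group = get_world_group()
    # The CPU group carries the coordination traffic for setting up the
    # NCCL communicator; the device pins it to this rank's GPU.
    return PyNcclCommunicator(world_group.cpu_group,
                              device=world_group.device)
```

Note also the import change in the first hunk: `graph_capture` now comes from `vllm.distributed.parallel_state` rather than `vllm.distributed.communication_op`.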