diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d6c9ee680abf..1536759c06bd 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -619,11 +619,13 @@ steps:
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
   - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
   - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
 
 - label: Distributed Tests (2 GPUs) # 40min
diff --git a/tests/distributed/test_node_count.py b/tests/distributed/test_node_count.py
new file mode 100644
index 000000000000..e3c36ef5ef37
--- /dev/null
+++ b/tests/distributed/test_node_count.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+
+import torch.distributed as dist
+
+from vllm.distributed.parallel_state import _node_count
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port
+
+if __name__ == "__main__":
+    dist.init_process_group(backend="gloo")
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
+    else:
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+        test_result = _node_count(pg)
+
+        # Expected node count based on the NUM_NODES environment variable
+        expected = int(os.environ.get("NUM_NODES", "1"))
+
+        assert test_result == expected, \
+            f"Expected {expected} nodes, got {test_result}"
+
+        if pg == dist.group.WORLD:
+            print(f"Node count test passed! Got {test_result} nodes "
+                  f"when using torch distributed!")
+        else:
+            print(f"Node count test passed! Got {test_result} nodes "
+                  f"when using StatelessProcessGroup!")
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 126160b09553..50dbbf50e9fc 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -802,6 +802,7 @@ def combine(self, hidden_states) -> torch.Tensor:
 
 
 _WORLD: Optional[GroupCoordinator] = None
+_NODE_COUNT: Optional[int] = None
 
 
 def get_world_group() -> GroupCoordinator:
@@ -961,10 +962,13 @@ def init_distributed_environment(
         local_rank = envs.LOCAL_RANK
     else:
         local_rank = rank
-    global _WORLD
+    global _WORLD, _NODE_COUNT
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
         _WORLD = init_world_group(ranks, local_rank, backend)
+        _NODE_COUNT = _node_count(_WORLD.cpu_group)
+        logger.debug("Detected %d nodes in the distributed environment",
+                     _NODE_COUNT)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size")
@@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_node_count() -> int:
+    """Return the total number of nodes in the distributed environment."""
+    assert _NODE_COUNT is not None, (
+        "distributed environment is not initialized")
+    return _NODE_COUNT
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
@@ -1189,10 +1200,11 @@ def destroy_model_parallel():
 
 
 def destroy_distributed_environment():
-    global _WORLD
+    global _WORLD, _NODE_COUNT
     if _WORLD:
         _WORLD.destroy()
     _WORLD = None
+    _NODE_COUNT = None
     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()
 
@@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
             aggregated_data += rank_data
 
     return [x == 1 for x in aggregated_data.tolist()]
+
+
+def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
+    """
+    Returns the total number of nodes in the process group.
+
+    Args:
+        pg: The process group to analyze
+
+    Returns:
+        int: The total number of nodes
+    """
+    if isinstance(pg, ProcessGroup):
+        world_size = torch.distributed.get_world_size(group=pg)
+    else:
+        world_size = pg.world_size
+
+    if world_size == 1:
+        return 1
+
+    # Build node assignment map
+    node_assignment = [0] * world_size  # rank -> node_id
+    next_node_id = 0
+
+    for current_rank in range(world_size):
+        if node_assignment[current_rank] != 0:
+            continue  # Already assigned to a node
+
+        # Assign current rank to a new node
+        next_node_id += 1
+        node_assignment[current_rank] = next_node_id
+
+        # Find all ranks on the same node as current_rank
+        same_node_flags = in_the_same_node_as(pg, current_rank)
+        for other_rank, is_same_node in enumerate(same_node_flags):
+            if is_same_node and node_assignment[other_rank] == 0:
+                node_assignment[other_rank] = next_node_id
+
+    return next_node_id
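
The `_node_count` helper added above groups ranks into nodes by repeatedly calling the collective `in_the_same_node_as`. As a minimal sketch of that grouping pass, the snippet below replays the same logic over a hypothetical rank-to-hostname map, with a plain list comparison standing in for the collective same-node check; `count_nodes` and the hostnames are illustrative stand-ins, not vLLM APIs.

# Minimal sketch of the _node_count grouping pass, assuming a
# hypothetical rank -> hostname map in place of a real process group.
# count_nodes and the hostnames below are illustrative, not vLLM APIs.

def count_nodes(hostnames: list[str]) -> int:
    world_size = len(hostnames)
    if world_size == 1:
        return 1

    node_assignment = [0] * world_size  # rank -> node_id, 0 = unassigned
    next_node_id = 0

    for current_rank in range(world_size):
        if node_assignment[current_rank] != 0:
            continue  # already grouped with an earlier rank

        # Open a new node for this rank.
        next_node_id += 1
        node_assignment[current_rank] = next_node_id

        # Stand-in for in_the_same_node_as(pg, current_rank): here two
        # ranks share a node iff they report the same hostname.
        same_node_flags = [h == hostnames[current_rank] for h in hostnames]
        for other_rank, is_same_node in enumerate(same_node_flags):
            if is_same_node and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

    return next_node_id


# Four ranks spread across two hosts -> 2 nodes; a single rank -> 1 node.
assert count_nodes(["node-a", "node-a", "node-b", "node-b"]) == 2
assert count_nodes(["node-a"]) == 1

Because `in_the_same_node_as` is a collective, every rank must walk the loop in the same deterministic order; the early `continue` also keeps the number of collective calls at one per node rather than one per rank.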