Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,13 @@ steps:
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code

- label: Distributed Tests (2 GPUs) # 40min
Expand Down
43 changes: 43 additions & 0 deletions tests/distributed/test_node_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import torch.distributed as dist

from vllm.distributed.parallel_state import _node_count
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port

if __name__ == "__main__":
    # Launched once per rank by torchrun; gloo keeps the test CPU-only.
    dist.init_process_group(backend="gloo")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Rank 0 chooses a rendezvous (ip, port) and shares it with all ranks.
    handshake = [get_ip(), get_open_port()] if rank == 0 else [None, None]
    dist.broadcast_object_list(handshake, src=0)
    ip, port = handshake

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)

    # Expected node count is injected by the CI harness via NUM_NODES.
    expected = int(os.environ.get("NUM_NODES", "1"))

    # Verify _node_count against both group flavors it accepts.
    for group in (dist.group.WORLD, stateless_pg):
        detected = _node_count(group)

        assert detected == expected, \
            f"Expected {expected} nodes, got {detected}"

        if group is dist.group.WORLD:
            print(f"Node count test passed! Got {detected} nodes "
                  f"when using torch distributed!")
        else:
            print(f"Node count test passed! Got {detected} nodes "
                  f"when using StatelessProcessGroup!")
55 changes: 53 additions & 2 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,7 @@ def combine(self, hidden_states) -> torch.Tensor:


_WORLD: Optional[GroupCoordinator] = None
_NODE_COUNT: Optional[int] = None


def get_world_group() -> GroupCoordinator:
Expand Down Expand Up @@ -961,10 +962,13 @@ def init_distributed_environment(
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD
global _WORLD, _NODE_COUNT
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
_NODE_COUNT = _node_count(_WORLD.cpu_group)
logger.debug("Detected %d nodes in the distributed environment",
_NODE_COUNT)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
Expand Down Expand Up @@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group


def get_node_count() -> int:
    """Number of distinct nodes participating in the distributed environment.

    Valid only after ``init_distributed_environment`` has populated the
    module-level ``_NODE_COUNT`` cache.
    """
    assert _NODE_COUNT is not None, (
        "distributed environment is not initialized")
    return _NODE_COUNT


def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
Expand All @@ -1189,10 +1200,11 @@ def destroy_model_parallel():


def destroy_distributed_environment():
    """Tear down the world group and reset cached distributed state."""
    global _WORLD, _NODE_COUNT
    if _WORLD:
        _WORLD.destroy()
    # Reset the caches so a later init_distributed_environment() starts clean.
    _WORLD = None
    _NODE_COUNT = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

Expand Down Expand Up @@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
aggregated_data += rank_data

return [x == 1 for x in aggregated_data.tolist()]


def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Count the distinct physical nodes covered by a process group.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    if isinstance(pg, ProcessGroup):
        size = torch.distributed.get_world_size(group=pg)
    else:
        size = pg.world_size

    # A single process is trivially a single node.
    if size == 1:
        return 1

    # node_of[r] is the 1-based node id assigned to rank r, or None while
    # the rank has not been grouped yet.
    node_of = [None] * size
    num_nodes = 0

    for seed_rank in range(size):
        if node_of[seed_rank] is not None:
            # Already grouped under an earlier seed rank.
            continue

        # seed_rank opens a new node; every rank reported as co-located
        # with it joins that node.  NOTE(review): in_the_same_node_as
        # appears to be a collective, and every rank walks this loop in
        # the same deterministic order, so the call sequence matches
        # across ranks.
        num_nodes += 1
        node_of[seed_rank] = num_nodes
        for peer, same_node in enumerate(in_the_same_node_as(pg, seed_rank)):
            if same_node and node_of[peer] is None:
                node_of[peer] = num_nodes

    return num_nodes