diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst
index 57b7917b5853..31ae30ad302b 100644
--- a/docs/source/getting_started/tpu-installation.rst
+++ b/docs/source/getting_started/tpu-installation.rst
@@ -8,7 +8,7 @@ vLLM supports Google Cloud TPUs using PyTorch XLA.
 Requirements
 ------------
 
-* Google Cloud TPU VM (single host)
+* Google Cloud TPU VM (single & multi host)
 * TPU versions: v5e, v5p, v4
 * Python: 3.10
 
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index 16525887cf4e..16ec84b43cac 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,3 +1,4 @@
+import ray
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -18,9 +19,15 @@ def __init__(self, group: ProcessGroup):
             return
         self.disabled = False
 
-        local_rank = dist.get_rank(group)
-        world_size = dist.get_world_size(group)
-        pjrt.initialize_multiprocess(local_rank, world_size)
+        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
+        # must be used together. Therefore, the local rank and world size can
+        # be simply calculated as follows.
+        global_rank = dist.get_rank(group)
+        global_world_size = dist.get_world_size(group)
+        num_nodes = len(ray.nodes())
+        local_world_size = global_world_size // num_nodes
+        local_rank = global_rank % local_world_size
+        pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
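
For context on the new rank arithmetic in tpu_communicator.py: the change assumes every node contributes the same number of TPU workers, so the per-node values can be derived from the global rank, global world size, and node count alone. Below is a minimal standalone sketch of that calculation with made-up example numbers (the helper name and the example values are illustrative, not part of the PR):

# Sketch of the per-node rank math used in the diff above, under the
# assumption that all nodes host the same number of TPU workers.
def local_rank_and_world_size(global_rank: int, global_world_size: int,
                              num_nodes: int) -> tuple[int, int]:
    local_world_size = global_world_size // num_nodes  # workers per node
    local_rank = global_rank % local_world_size        # position within the node
    return local_rank, local_world_size

# Example: 2 nodes with 4 TPU chips each -> global world size of 8.
# Global rank 5 lands on the second node as its local rank 1.
assert local_rank_and_world_size(5, 8, 2) == (1, 4)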