From e6c1765f50fd84dc1ae96abf7d803b642a5525bd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 25 Jul 2024 18:31:33 -0700 Subject: [PATCH 1/2] [Misc] Support TPU in initialize_ray_cluster --- vllm/executor/ray_utils.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index fcbfa30d7a38..b23039ff997c 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -3,7 +3,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest -from vllm.utils import get_ip, is_hip, is_xpu +from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -93,6 +93,7 @@ def initialize_ray_cluster( # Placement group is already set. return + device_str = "GPU" if not is_tpu() else "TPU" # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() if current_placement_group: @@ -101,24 +102,27 @@ def initialize_ray_cluster( # Verify that we can use the placement group. gpu_bundles = 0 for bundle in bundles: - bundle_gpus = bundle.get("GPU", 0) + bundle_gpus = bundle.get(device_str, 0) if bundle_gpus > 1: raise ValueError( - "Placement group bundle cannot have more than 1 GPU.") + "Placement group bundle cannot have more than 1 " + f"{device_str}.") if bundle_gpus: gpu_bundles += 1 if parallel_config.world_size > gpu_bundles: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the placement group.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") else: - num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) + num_gpus_in_cluster = ray.cluster_resources().get(device_str, 0) if parallel_config.world_size > num_gpus_in_cluster: raise ValueError( - "The number of required GPUs exceeds the total number of " - "available GPUs in the cluster.") + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") # Create a new placement group - placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size) + placement_group_specs = ([{ + device_str: 1 + }] * parallel_config.world_size) current_placement_group = ray.util.placement_group( placement_group_specs) # Wait until PG is ready - this will block until all From 749a98055bcc86c325f0f83dc125fd430a45ffc3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 26 Jul 2024 10:24:09 -0700 Subject: [PATCH 2/2] gpu -> device --- vllm/executor/ray_utils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index b23039ff997c..58b864070f72 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -100,22 +100,24 @@ def initialize_ray_cluster( # We are in a placement group bundles = current_placement_group.bundle_specs # Verify that we can use the placement group. - gpu_bundles = 0 + device_bundles = 0 for bundle in bundles: - bundle_gpus = bundle.get(device_str, 0) - if bundle_gpus > 1: + bundle_devices = bundle.get(device_str, 0) + if bundle_devices > 1: raise ValueError( "Placement group bundle cannot have more than 1 " f"{device_str}.") - if bundle_gpus: - gpu_bundles += 1 - if parallel_config.world_size > gpu_bundles: + if bundle_devices: + device_bundles += 1 + if parallel_config.world_size > device_bundles: raise ValueError( f"The number of required {device_str}s exceeds the total " - f"number of available {device_str}s in the placement group.") + f"number of available {device_str}s in the placement group." + f"Required number of devices: {parallel_config.world_size}. " + f"Total number of devices: {device_bundles}.") else: - num_gpus_in_cluster = ray.cluster_resources().get(device_str, 0) - if parallel_config.world_size > num_gpus_in_cluster: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if parallel_config.world_size > num_devices_in_cluster: raise ValueError( f"The number of required {device_str}s exceeds the total " f"number of available {device_str}s in the placement group.")