@@ -3,7 +3,7 @@
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_ip, is_hip, is_xpu
+from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -93,32 +93,38 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return
 
+    device_str = "GPU" if not is_tpu() else "TPU"
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
         # Verify that we can use the placement group.
-        gpu_bundles = 0
+        device_bundles = 0
         for bundle in bundles:
-            bundle_gpus = bundle.get("GPU", 0)
-            if bundle_gpus > 1:
+            bundle_devices = bundle.get(device_str, 0)
+            if bundle_devices > 1:
                 raise ValueError(
-                    "Placement group bundle cannot have more than 1 GPU.")
-            if bundle_gpus:
-                gpu_bundles += 1
-        if parallel_config.world_size > gpu_bundles:
+                    "Placement group bundle cannot have more than 1 "
+                    f"{device_str}.")
+            if bundle_devices:
+                device_bundles += 1
+        if parallel_config.world_size > device_bundles:
             raise ValueError(
-                "The number of required GPUs exceeds the total number of "
-                "available GPUs in the placement group.")
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the placement group. "
+                f"Required number of devices: {parallel_config.world_size}. "
+                f"Total number of devices: {device_bundles}.")
     else:
-        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
-        if parallel_config.world_size > num_gpus_in_cluster:
+        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
+        if parallel_config.world_size > num_devices_in_cluster:
             raise ValueError(
-                "The number of required GPUs exceeds the total number of "
-                "available GPUs in the cluster.")
+                f"The number of required {device_str}s exceeds the total "
+                f"number of available {device_str}s in the cluster.")
         # Create a new placement group
-        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
+        placement_group_specs = ([{
+            device_str: 1
+        }] * parallel_config.world_size)
         current_placement_group = ray.util.placement_group(
            placement_group_specs)
         # Wait until PG is ready - this will block until all
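For context, here is a minimal runnable sketch of the flow this patch generalizes, written against the public Ray API. The function name `ensure_placement_group` and the `world_size`/`use_tpu` parameters are illustrative stand-ins for `parallel_config.world_size` and vLLM's `is_tpu()` check; this is a sketch under those assumptions, not the patch itself.

```python
# Sketch of device-aware placement-group setup (not vLLM code).
import ray
from ray.util import get_current_placement_group, placement_group


def ensure_placement_group(world_size: int, use_tpu: bool = False):
    # Pick the Ray resource name for this accelerator type.
    device_str = "TPU" if use_tpu else "GPU"
    pg = get_current_placement_group()
    if pg is not None:
        # Reuse the ambient placement group, but verify it: count the
        # bundles that actually hold a device of the requested type.
        device_bundles = sum(
            1 for bundle in pg.bundle_specs if bundle.get(device_str, 0))
        if world_size > device_bundles:
            raise ValueError(
                f"Need {world_size} {device_str} bundles, "
                f"found {device_bundles}.")
        return pg
    # Otherwise create a fresh group with one device per bundle and
    # block until Ray has reserved all of them.
    pg = placement_group([{device_str: 1}] * world_size)
    ray.get(pg.ready(), timeout=1800)
    return pg


if __name__ == "__main__":
    ray.init()
    pg = ensure_placement_group(world_size=2, use_tpu=False)
    print(pg.bundle_specs)
```

Capping each bundle at a single device mirrors the check in the diff above ("cannot have more than 1 {device_str}") and matches a one-worker-per-device model, leaving Ray free to spread bundles across hosts.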