diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 22487f5524dd..c2140fbffec9 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,5 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. +ray triton # To avoid import errors diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b83a83bb177d..c53a2f91b89d 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -55,8 +55,8 @@ class PallasMetadata(AttentionMetadata): # Currently, input sequences can only contain all prefills # or all decoding. - block_tables: Optional[torch.Tensor] - context_lens: Optional[torch.Tensor] + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None @property def prefill_metadata(self) -> Optional["PallasMetadata"]: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 48d530589221..004348d4c49a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -394,8 +394,14 @@ def _get_executor_cls(cls, from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor elif engine_config.device_config.device_type == "tpu": - from vllm.executor.tpu_executor import TPUExecutor - executor_class = TPUExecutor + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_tpu_executor import RayTPUExecutor + executor_class = RayTPUExecutor + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutor + executor_class = TPUExecutor elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py new file mode 100644 index 000000000000..7048d4798072 --- /dev/null +++ b/vllm/executor/ray_tpu_executor.py @@ -0,0 +1,313 @@ +import asyncio +import os +from collections import defaultdict +from itertools import islice, repeat +from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, + Union) + +import vllm.envs as envs +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.executor.tpu_executor import TPUExecutor +from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) + +if ray is not None: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +class RayTPUExecutor(TPUExecutor): + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. 
+ self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + + def _init_executor(self) -> None: + assert self.parallel_config.distributed_executor_backend == "ray" + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel TPU workers. + self._init_workers_ray(placement_group) + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Create the workers. + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("TPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + assert self.speculative_config is None + worker_module_name = "vllm.worker.tpu_worker" + worker_class_name = "TPUWorker" + + worker = ray.remote( + num_cpus=0, + resources={"TPU": 1}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any TPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "TPU node.") + + # Get the set of TPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + + node_workers = defaultdict(list) + for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for _ in worker_node_and_gpu_ids] + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) + + if len(node_workers) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. 
`get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, + max_concurrent_workers: Optional[int] = None, + use_ray_compiled_dag: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers. Can be used in the following + ways: + + - async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than blocking + on the results. + - args/kwargs: All workers share the same args/kwargs + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) + + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + + if async_run_remote_workers_only: + # Just return futures + return ray_worker_outputs + + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + # Get the results of the ray workers. 
+ if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_blocks = self._run_workers("determine_num_available_blocks", ) + num_tpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + return num_tpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) + + # Only the driver worker returns the sampling results. + return self._driver_execute_model(execute_model_req) + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + +class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_method = make_async(self.driver_worker.execute_method) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. 
+ return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_method("execute_model", + execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 8a8b412db673..e5bb101fc7df 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,6 +1,7 @@ import time from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from unittest.mock import patch import numpy as np import torch @@ -45,6 +46,7 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int best_of: List[int] seq_groups: List[List[int]] + virtual_engine: int = 0 def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -55,6 +57,9 @@ def as_broadcastable_tensor_dict( "t": self.t, "p": self.p, "num_samples": self.num_samples, + "best_of": self.best_of, + "seq_groups": self.seq_groups, + "virtual_engine": self.virtual_engine, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -113,16 +118,30 @@ def __init__( def load_model(self) -> None: self.device = self.device_config.device - model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - parallel_config=self.parallel_config, - cache_config=self.cache_config, - scheduler_config=self.scheduler_config, - multimodal_config=self.multimodal_config, - lora_config=None, - ) + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xm.get_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + parallel_config=self.parallel_config, + cache_config=self.cache_config, + scheduler_config=self.scheduler_config, + multimodal_config=self.multimodal_config, + lora_config=None, + ) model = model.eval() xm.wait_device_ops() @@ -463,10 +482,11 @@ def make_model_input_from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend) return model_input + @torch.no_grad() def execute_model( self, model_input: ModelInputForTPU, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: Optional[List[Any]], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> List[SamplerOutput]: diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 03011e03058d..c88aba7ae08c 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -70,13 +70,13 @@ def __init__( def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" - self.device = xm.xla_device() - self.device_config.device = self.device torch.set_grad_enabled(False) torch.set_default_dtype(self.model_config.dtype) - # NOTE(woosuk): This is just a hack to initialize the TP group. - # This cannot perform the actual communication ops. + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. init_distributed_environment( world_size=self.parallel_config.world_size, rank=self.rank, @@ -88,6 +88,11 @@ def init_device(self) -> None: self.parallel_config.tensor_parallel_size, self.parallel_config.pipeline_parallel_size) + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + # Set random seed. set_random_seed(self.model_config.seed) xm.set_rng_state(self.model_config.seed, self.device) @@ -200,8 +205,7 @@ def get_cache_block_size_bytes(self) -> int: @property def do_metadata_broadcast(self) -> bool: - # TODO(woosuk): Support TP. - return False + return self.parallel_config.tensor_parallel_size > 1 @property def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: