diff --git a/Dockerfile.tpu b/Dockerfile.tpu new file mode 100644 index 000000000000..931c844c08dc --- /dev/null +++ b/Dockerfile.tpu @@ -0,0 +1,19 @@ +ARG NIGHTLY_DATE="20240601" +ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" + +FROM $BASE_IMAGE + +WORKDIR /workspace +COPY . /workspace/vllm + +ENV VLLM_TARGET_DEVICE="tpu" +# Install aiohttp separately to avoid build errors. +RUN pip install aiohttp +# Install the TPU and Pallas dependencies. +RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + +# Build vLLM. +RUN cd /workspace/vllm && python setup.py develop + +CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 1a41b66b3882..17edb7515964 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -189,7 +189,7 @@ def run_to_completion(profile_dir: Optional[str] = None): "--device", type=str, default="cuda", - choices=["cuda", "cpu"], + choices=["cuda", "cpu", "tpu"], help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument('--block-size', type=int, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 90f7433e0ae2..07b2f85410e3 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -346,7 +346,7 @@ def main(args: argparse.Namespace): "--device", type=str, default="cuda", - choices=["cuda", "cpu"], + choices=["cuda", "cpu", "tpu"], help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument( "--enable-prefix-caching", diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst new file mode 100644 index 000000000000..3627600e1f23 --- /dev/null +++ b/docs/source/getting_started/tpu-installation.rst @@ -0,0 +1,75 @@ +.. _installation_tpu: + +Installation with TPU +===================== + +vLLM supports Google Cloud TPUs using PyTorch XLA. + +Requirements +------------ + +* Google Cloud TPU VM (single host) +* TPU versions: v5e, v5p, v4 +* Python: 3.10 + +Installation options: + +1. :ref:`Build a docker image with Dockerfile `. +2. :ref:`Build from source `. + +.. _build_docker_tpu: + +Build a docker image with :code:`Dockerfile.tpu` +------------------------------------------------ + +`Dockerfile.tpu `_ is provided to build a docker image with TPU support. + +.. code-block:: console + + $ docker build -f Dockerfile.tpu -t vllm-tpu . + + +You can run the docker image with the following command: + +.. code-block:: console + + $ # Make sure to add `--privileged --net host --shm-size=16G`. + $ docker run --privileged --net host --shm-size=16G -it vllm-tpu + + +.. _build_from_source_tpu: + +Build from source +----------------- + +You can also build and install the TPU backend from source. + +First, install the dependencies: + +.. code-block:: console + + $ # (Recommended) Create a new conda environment. + $ conda create -n myenv python=3.10 -y + $ conda activate myenv + + $ # Clean up the existing torch and torch-xla packages. + $ pip uninstall torch torch-xla -y + + $ # Install PyTorch and PyTorch XLA. 
+ $ export DATE="+20240601" + $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl + $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl + + $ # Install JAX and Pallas. + $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html + $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + + $ # Install other build dependencies. + $ pip install packaging aiohttp + + +Next, build vLLM from source. This will only take a few seconds: + +.. code-block:: console + + $ VLLM_TARGET_DEVICE="tpu" python setup.py develop diff --git a/docs/source/index.rst b/docs/source/index.rst index 807251d02974..b7c0d5b88007 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -63,8 +63,9 @@ Documentation getting_started/installation getting_started/amd-installation - getting_started/neuron-installation getting_started/cpu-installation + getting_started/neuron-installation + getting_started/tpu-installation getting_started/quickstart getting_started/debugging getting_started/examples/examples_index diff --git a/requirements-tpu.txt b/requirements-tpu.txt new file mode 100644 index 000000000000..22487f5524dd --- /dev/null +++ b/requirements-tpu.txt @@ -0,0 +1,7 @@ +# Common dependencies +-r requirements-common.txt + +# Dependencies for TPU +# Currently, the TPU backend uses a nightly version of PyTorch XLA. +# You can install the dependencies in Dockerfile.tpu. +triton # To avoid import errors diff --git a/setup.py b/setup.py index 53a697232b44..12e5c34568f7 100644 --- a/setup.py +++ b/setup.py @@ -206,9 +206,9 @@ def build_extensions(self) -> None: def _is_cuda() -> bool: - return VLLM_TARGET_DEVICE == "cuda" \ - and torch.version.cuda is not None \ - and not _is_neuron() + has_cuda = torch.version.cuda is not None + return (VLLM_TARGET_DEVICE == "cuda" and has_cuda + and not (_is_neuron() or _is_tpu())) def _is_hip() -> bool: @@ -225,10 +225,18 @@ def _is_neuron() -> bool: return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron" +def _is_tpu() -> bool: + return VLLM_TARGET_DEVICE == "tpu" + + def _is_cpu() -> bool: return VLLM_TARGET_DEVICE == "cpu" +def _build_custom_ops() -> bool: + return _is_cuda() or _is_hip() or _is_cpu() + + def _install_punica() -> bool: return envs.VLLM_INSTALL_PUNICA_KERNELS @@ -325,6 +333,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_tpu(): + version += "+tpu" elif _is_cpu(): version += "+cpu" else: @@ -372,6 +382,8 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-rocm.txt") elif _is_neuron(): requirements = _read_requirements("requirements-neuron.txt") + elif _is_tpu(): + requirements = _read_requirements("requirements-tpu.txt") elif _is_cpu(): requirements = _read_requirements("requirements-cpu.txt") else: @@ -385,7 +397,7 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) -if not _is_neuron(): +if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) if _install_punica(): @@ -428,6 +440,6 @@ def _read_requirements(filename: str) -> List[str]: extras_require={ 
"tensorizer": ["tensorizer>=2.9.0"], }, - cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, + cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, package_data=package_data, ) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py new file mode 100644 index 000000000000..b203c5ec54c9 --- /dev/null +++ b/vllm/attention/backends/pallas.py @@ -0,0 +1,232 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. +import torch_xla.experimental.dynamo_set_buffer_donor + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "PallasMetadata": + return PallasMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + raise NotImplementedError("swap_blocks is not implemented.") + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + # TODO(woosuk): Implement this. + raise NotImplementedError("copy_blocks is not implemented.") + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. 
+ block_tables: Optional[torch.Tensor] + context_lens: Optional[torch.Tensor] + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + assert self.block_tables is None + assert self.context_lens is None + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() + if not tpu_type.endswith("lite"): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], + attn_metadata: PallasMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + key_cache = [num_kv_heads, num_blocks, block_size, head_size] + value_cache = [num_kv_heads, num_blocks, block_size, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + assert kv_scale == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0] is not None: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. 
+ if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention requires [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Decoding run. + assert kv_cache is not None + + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + if self.megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = self.megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." + if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query.squeeze(dim=1), + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query.squeeze(dim=1), + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + ) + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7253483f9a0b..3f0e29c73e0c 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -7,7 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_cpu, is_hip +from vllm.utils import is_cpu, is_hip, is_tpu logger = init_logger(__name__) @@ -18,6 +18,7 @@ class _Backend(enum.Enum): ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() + PALLAS = enum.auto() @lru_cache(maxsize=None) @@ -66,6 +67,10 @@ def get_attn_backend( "Please make sure --enforce-eager is set.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend + elif backend == _Backend.PALLAS: + logger.info("Using Pallas backend.") + from vllm.attention.backends.pallas import PallasAttentionBackend + return PallasAttentionBackend else: raise ValueError("Invalid attention backend.") @@ -80,7 +85,6 @@ def which_attn_to_use( block_size: int, ) -> _Backend: """Returns which flash attention backend to use.""" - # Default case. 
selected_backend = _Backend.FLASH_ATTN @@ -100,6 +104,11 @@ def which_attn_to_use( logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA + if is_tpu(): + if selected_backend != _Backend.PALLAS: + logger.info("Cannot use %s backend on TPU.", selected_backend) + return _Backend.PALLAS + if is_hip(): # AMD GPUs. selected_backend = (_Backend.ROCM_FLASH if selected_backend diff --git a/vllm/config.py b/vllm/config.py index 7ffb93c19ede..a0a0c03ab0df 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron +from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron, is_tpu if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -732,6 +732,8 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if is_neuron(): self.device_type = "neuron" + elif is_tpu(): + self.device_type = "tpu" elif is_cpu(): self.device_type = "cpu" else: @@ -745,6 +747,8 @@ def __init__(self, device: str = "auto") -> None: # Some device types require processing inputs on CPU if self.device_type in ["neuron"]: self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None else: # Set device with device type self.device = torch.device(self.device_type) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cd29db7d7a9e..227de5475b94 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -504,7 +504,7 @@ def add_cli_args( parser.add_argument("--device", type=str, default=EngineArgs.device, - choices=["auto", "cuda", "neuron", "cpu"], + choices=["auto", "cuda", "neuron", "cpu", "tpu"], help='Device type for vLLM execution.') # Related to Vision-language models such as llava diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index aa1f07b5bdc2..943402c865bd 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -375,6 +375,9 @@ def from_engine_args( if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync + elif engine_config.device_config.device_type == "tpu": + from vllm.executor.tpu_executor import TPUExecutorAsync + executor_class = TPUExecutorAsync elif engine_config.device_config.device_type == "cpu": assert distributed_executor_backend is None, ( "Distributed execution is not supported with the CPU backend.") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4f56bbd5c2dc..ea754758492f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -341,6 +341,9 @@ def from_engine_args( if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor + elif engine_config.device_config.device_type == "tpu": + from vllm.executor.tpu_executor import TPUExecutor + executor_class = TPUExecutor elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor diff --git a/vllm/envs.py b/vllm/envs.py index f0513b9af276..f03b69f4b886 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -27,6 +27,7 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None 
VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/" VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -217,6 +218,11 @@ # Default is 5 seconds "VLLM_IMAGE_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), + + # Path to the XLA persistent cache directory. + # Only used for XLA devices such as TPUs. + "VLLM_XLA_CACHE_PATH": + lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"), } # end-env-vars-definition diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py new file mode 100644 index 000000000000..7061ad85f88c --- /dev/null +++ b/vllm/executor/tpu_executor.py @@ -0,0 +1,101 @@ +from typing import List, Set, Tuple + +import torch + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) + +logger = init_logger(__name__) + + +class TPUExecutor(ExecutorBase): + + def _init_executor(self) -> None: + assert not self.scheduler_config.chunked_prefill_enabled, ( + "Chunked prefill is not yet supported for TPU backend") + assert not self.speculative_config, ( + "Speculative decoding is not yet supported for TPU backend") + if self.model_config.dtype in (torch.float16, torch.float32): + logger.warning( + "The TPU backend currently does not support %s. " + "Using bfloat16 instead.", self.model_config.dtype) + self.model_config.dtype = torch.bfloat16 + + # Instantiate the worker and load the model to the device. + self._init_worker() + + def _init_worker(self): + from vllm.worker.tpu_worker import TPUWorker + + assert self.parallel_config.world_size == 1, ( + "TPUExecutor currently only supports a single TPU chip.") + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = TPUWorker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + self.cache_config, + self.load_config, + self.vision_language_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def initialize_cache( + self, + num_gpu_blocks: int, + num_cpu_blocks: int, + ) -> None: + """Initialize the KV cache by invoking the underlying worker.""" + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. + logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. 
+ """ + return self.driver_worker.determine_num_available_blocks() + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError("LoRA is not implemented for TPU backend.") + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError("LoRA is not implemented for TPU backend.") + + def list_loras(self) -> Set[int]: + raise NotImplementedError("LoRA is not implemented for TPU backend.") + + def check_health(self) -> None: + # TPUExecutor will always be healthy as long as it's running. + return + + +class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + sexecute_model_req: ExecuteModelRequest, + ) -> SamplerOutput: + output = await make_async(self.driver_worker.execute_model + )(sexecute_model_req) + return output diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 1d49213cd4ab..56aa629ae345 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,6 +1,6 @@ import torch.nn as nn -from vllm.utils import is_cpu, is_hip +from vllm.utils import is_cpu, is_hip, is_tpu class CustomOp(nn.Module): @@ -56,5 +56,7 @@ def dispatch_forward(self): return self.forward_hip elif is_cpu(): return self.forward_cpu + elif is_tpu(): + return self.forward_tpu else: return self.forward_cuda diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d2652106b844..792c4729355a 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -28,6 +28,7 @@ import torch.nn as nn from vllm.model_executor.custom_op import CustomOp +from vllm.utils import is_tpu def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) +def _apply_rotary_emb( + x: torch.Tensor, + freqs_cis: torch.Tensor, +) -> torch.Tensor: + x_ = torch.view_as_complex( + torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1)) + x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) + x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) + x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], + -1).transpose(1, 2) + return x_out + + class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" @@ -64,8 +78,14 @@ def __init__( self.dtype = dtype cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) - self.register_buffer("cos_sin_cache", cache, persistent=False) + self.use_native2 = is_tpu() and is_neox_style + if not self.use_native2: + cache = cache.to(dtype) + self.register_buffer("cos_sin_cache", cache, persistent=False) + else: + cos, sin = cache.chunk(2, dim=-1) + freqs_cis = cos + 1j * sin + self.register_buffer("freqs_cis", freqs_cis, persistent=False) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" @@ -100,7 +120,11 @@ def forward_native( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - """PyTorch-native implementation equivalent to forward().""" + """A PyTorch-native implementation equivalent to forward(). + + This method mimics the implementation of the custom CUDA kernel + used in `forward_cuda()`. 
+ """ query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) @@ -138,6 +162,42 @@ def forward_native( key = key.flatten(-2) return query, key + def forward_native2( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Another PyTorch-native implementation of forward(). + + This method might perform better than `forward_native()` when compiled. + """ + if positions.dim() == 1: + batch_size = 1 + seq_len = positions.shape[0] + else: + batch_size, seq_len = positions.shape + if offsets is not None: + positions = positions + offsets + freqs_cis = self.freqs_cis.index_select(0, positions.flatten()) + freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1) + + query_shape = query.shape + query = query.view(batch_size, seq_len, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, freqs_cis) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(batch_size, seq_len, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, freqs_cis) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + def forward_cuda( self, positions: torch.Tensor, @@ -161,6 +221,17 @@ def forward_cuda( self.cos_sin_cache, self.is_neox_style) return query, key + def forward_tpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + forward_fn = (self.forward_native2 + if self.use_native2 else self.forward_native) + return forward_fn(positions, query, key, offsets) + def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}" diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 9c2eaee2eda5..f4c3dcbace24 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -34,6 +34,7 @@ pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.vlm_base import VisionLanguageModelBase from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_tpu logger = init_logger(__name__) @@ -227,12 +228,26 @@ def _get_weights_iterator( if self.load_config.load_format == LoadFormat.NPCACHE: # Currently np_cache only support *.bin checkpoints assert use_safetensors is False - return np_cache_weights_iterator(model_name_or_path, - self.load_config.download_dir, - hf_folder, hf_weights_files) - if use_safetensors: - return safetensors_weights_iterator(hf_weights_files) - return pt_weights_iterator(hf_weights_files) + weights_iterator = np_cache_weights_iterator( + model_name_or_path, self.load_config.download_dir, hf_folder, + hf_weights_files) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + if is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. 
+ import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + return weights_iterator def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, diff --git a/vllm/utils.py b/vllm/utils.py index 54d446b23350..af585929d1a0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -146,6 +146,15 @@ def is_neuron() -> bool: return transformers_neuronx is not None +@lru_cache(maxsize=None) +def is_tpu() -> bool: + try: + import libtpu + except ImportError: + libtpu = None + return libtpu is not None + + @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" @@ -546,6 +555,11 @@ def maybe_expand_dim(tensor: torch.Tensor, return tensor +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + def merge_dicts(dict1: Dict[Any, List[Any]], dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: """Merge 2 dicts that have key -> List of items. diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 2f0e59f7ae7c..341b177d4af2 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,7 +6,8 @@ from vllm.attention import get_attn_backend from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, + is_pin_memory_available) logger = init_logger(__name__) @@ -108,9 +109,5 @@ def get_cache_block_size( dtype = model_config.dtype else: dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - dtype_size = _get_dtype_size(dtype) + dtype_size = get_dtype_size(dtype) return dtype_size * total - - -def _get_dtype_size(dtype: torch.dtype) -> int: - return torch.tensor([], dtype=dtype).element_size() diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py new file mode 100644 index 000000000000..5003d3b0ca44 --- /dev/null +++ b/vllm/worker/tpu_model_runner.py @@ -0,0 +1,525 @@ +import time +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceGroupMetadata, + SequenceOutput) +from vllm.utils import make_tensor_with_pad + +logger = init_logger(__name__) + +_PAD_SLOT_ID = 0 # FIXME(woosuk) + + +class TPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + vision_language_config: Optional[VisionLanguageConfig] = None, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + 
self.load_config = load_config + self.vision_language_config = vision_language_config + + self.block_size = self.cache_config.block_size + self.max_num_blocks_per_seq = (self.model_config.max_model_len // + self.block_size) + self.block_tables = np.zeros( + (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), + dtype=np.int32) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + False, + ) + + def load_model(self) -> None: + self.device = self.device_config.device + + model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + parallel_config=self.parallel_config, + cache_config=self.cache_config, + scheduler_config=self.scheduler_config, + vision_language_config=self.vision_language_config, + lora_config=None, + ) + xm.wait_device_ops() + + model = ModelWrapper(model) + self.model = torch.compile(model, backend="openxla", fullgraph=True) + + def _dummy_run( + self, + batch_size: int, + seq_len: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + is_prompt: bool, + ) -> None: + if is_prompt: + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + block_tables=None, + context_lens=None, + ) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + else: + assert seq_len == 1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_seq), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + + # Dummy run. 
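        # Aside (illustration only, not part of the diff): the dummy tensors
        # built above only pin down the *shapes* that XLA traces and compiles.
        # Assuming a hypothetical block_size=16 and max_model_len=1024 (so 64
        # blocks per sequence), the decode dummy uses
        #     token_ids / position_ids : (batch_size, 1)   int32
        #     slot_mapping             : (batch_size, 1)   int64
        #     block_tables             : (batch_size, 64)  int32
        #     context_lens / input_lens: (batch_size,)     int32
        # while the prefill dummy first rounds seq_len up to a multiple of 16:
        #     (seq_len + 15) // 16 * 16   ->   1 -> 16, 17 -> 32, 100 -> 112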
+ self.model(token_ids, position_ids, kv_caches, attn_metadata, + input_lens, t, p) + + def warmup_model( + self, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> None: + # Prefill + logger.info("Compiling the model with different input shapes...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if seq_len >= self.model_config.max_model_len: + break + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.time() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Decode + start = time.time() + seq_len = 1 + batch_size = 1 + while True: + self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ): + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + prompt_lens: List[int] = [] + slot_mapping: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + # Could include output tokens when a request is preempted. + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.append(prompt_tokens) + input_positions.append(list(range(prompt_len))) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + slot_mapping.append([]) + for i in range(prompt_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + assert len(prompt_lens) > 0 + num_prefills = len(prompt_lens) + num_prefill_tokens = sum(prompt_lens) + + # Add paddings to make the shape [batch_size, max_prompt_len] where + # max_prompt_len is smallest power of 2 that is greater than or equal + # to the maximum prompt length. + # We need the 2D input shape because the Pallas FlashAttention kernel + # does not support packed 1D inputs. + # We pad the seq_len to powers of 2 to reduce the compilation overhead. + max_prompt_len = _get_padded_prefill_len(max(prompt_lens)) + input_tokens = make_tensor_with_pad(input_tokens, + max_prompt_len, + pad=0, + dtype=torch.int32, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, + max_prompt_len, + pad=0, + dtype=torch.int32, + device=self.device) + slot_mapping = make_tensor_with_pad(slot_mapping, + max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.int64, + device=self.device) + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int32, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, # NOTE: This is not used. 
+ num_decode_tokens=0, + slot_mapping=slot_mapping, + block_tables=None, + context_lens=None, + ) + return input_tokens, input_positions, attn_metadata, prompt_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ): + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + num_seq_groups = len(seq_group_metadata_list) + batch_size = _get_padded_batch_size(num_seq_groups) + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + self.block_tables[i, :len(block_table)] = block_table + + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + num_paddings = batch_size - num_seq_groups + input_tokens = input_tokens + [[0]] * num_paddings + input_positions = input_positions + [[0]] * num_paddings + slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings + context_lens = context_lens + [0] * num_paddings + + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device=self.device) + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device=self.device) + input_lens = torch.tensor([1] * batch_size, + dtype=torch.int32, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + return input_tokens, input_positions, attn_metadata, input_lens + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + padded_batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + t = [] + p = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.sampling_params is not None + sampling_params = seq_group_metadata.sampling_params + + t.append(sampling_params.temperature + if sampling_params.temperature >= 1e-5 else 1e-5) + p.append(sampling_params.top_p) + num_paddings = padded_batch_size - len(seq_group_metadata_list) + t += [1.0] * num_paddings + p += [1.0] * num_paddings + + t = torch.tensor(t, dtype=torch.float32, device=self.device) + p = torch.tensor(p, dtype=torch.float32, device=self.device) + return t, p + + def prepare_inputs( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ): + assert seq_group_metadata_list is not None + assert len(seq_group_metadata_list) > 0 + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. 
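        # Aside (illustration only, not part of the diff): the decode path
        # above pads the batch to a bucketed size so that only a handful of
        # batch shapes ever get compiled. A quick sketch of the bucketing done
        # by _get_padded_batch_size() (defined at the bottom of this file):
        #
        #     >>> [_get_padded_batch_size(n) for n in (1, 2, 3, 5, 9, 17, 33)]
        #     [1, 2, 4, 8, 16, 32, 48]
        #
        # The padded rows use token id 0, position 0, _PAD_SLOT_ID for the
        # slot mapping, and a context length of 0.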
+ if seq_group_metadata_list[0].is_prompt: + inputs = self._prepare_prompt(seq_group_metadata_list) + else: + inputs = self._prepare_decode(seq_group_metadata_list) + padded_batch_size = inputs[0].shape[0] + sample_inputs = self._prepare_sample(seq_group_metadata_list, + padded_batch_size) + return inputs + sample_inputs + + def _execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> List[CompletionSequenceGroupOutput]: + inputs = self.prepare_inputs(seq_group_metadata_list) + next_token_ids = self.model(inputs[0], inputs[1], kv_caches, + *inputs[2:]) + next_token_ids = next_token_ids.cpu().tolist() + + i = 0 + sampler_outputs = [] + for seq_group_metadata in seq_group_metadata_list: + seq_outputs = [] + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + next_token_id = next_token_ids[i] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: Logprob(0.0)})) + i += 1 + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return sampler_outputs + + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> SamplerOutput: + assert seq_group_metadata_list is not None + if seq_group_metadata_list[0].is_prompt: + # NOTE(woosuk): To reduce the compilation time, we only compile the + # prefill inputs with batch size 1. Because the scheduler is not + # aware of this limitation, we need to handle batch size > 1 + # internally by calling the model multiple times and concatenating + # the outputs. + # FIXME(woosuk): This is a temporary hack to not change the existing + # scheduler. We need to fix this in the future. + sampler_outputs = [] + for seq_group_metadata in seq_group_metadata_list: + sampler_outputs += self._execute_model([seq_group_metadata], + kv_caches) + else: + sampler_outputs = self._execute_model(seq_group_metadata_list, + kv_caches) + return SamplerOutput(sampler_outputs) + + +class ModelWrapper(nn.Module): + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model.eval() + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], + attn_metadata: AttentionMetadata, + input_lens: torch.Tensor, + t: torch.Tensor, + p: torch.Tensor, + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + attn_metadata: The Pallas attention metadata. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + """ + batch_size, seq_len = token_ids.shape + # Calculate the positions to sample from. + base_indicies = torch.arange( + batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + logits_indices = base_indicies + input_lens - 1 + + # FIXME(woosuk): This is a temporary hack to avoid using the existing + # sampler and sampling metadata. 
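        # Aside (illustration only, not part of the diff): a concrete example
        # of the index math above. With batch_size=2, a padded seq_len of 16,
        # and input_lens=[5, 8]:
        #     base_indicies  = [0, 16]          # arange(2) * 16
        #     logits_indices = [4, 23]          # base_indicies + input_lens - 1
        # After hidden_states.flatten(0, 1) below, rows 4 and 23 hold the
        # hidden states of the last real (non-padding) token of each sequence,
        # and selected_token_indices makes compute_logits() use exactly those
        # rows.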
+ sampling_metadata = SamplingMetadata( + seq_groups=[], + selected_token_indices=logits_indices, + categorized_sample_indices={}, + num_prompts=attn_metadata.num_prefills, + ) + + # Skip this in memory profiling at initialization. + if kv_caches[0][0] is not None: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indices = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indices *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indices.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + logits = logits / t.unsqueeze(dim=1) + # FIXME(woosuk): Disabled top-p sampling since it's too slow. + # logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + # FIXME(woosuk): best_of > 1 is not supported. + next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1) + return next_token_ids + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The Pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. Prompt lengths of at most 16 are padded to + # 16, and longer ones to the next power of two (always a multiple of 16), + # which keeps the number of compiled shapes small and is also good for + # performance.
+ if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + elif batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + logits_sorted = torch.sort(logits, dim=-1, descending=True).values + sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) + cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) + cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) + logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) + return logits diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py new file mode 100644 index 000000000000..04576015dadb --- /dev/null +++ b/vllm/worker/tpu_worker.py @@ -0,0 +1,198 @@ +import os +from typing import List, Optional, Tuple + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.worker.tpu_model_runner import TPUModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + +logger = init_logger(__name__) + + +class TPUWorker(LoraNotSupportedWorkerBase): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + vision_language_config: Optional[VisionLanguageConfig], + local_rank: int, + rank: int, + distributed_init_method: str, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.vision_language_config = vision_language_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + assert self.device_config.device_type == "tpu" + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + self.model_runner = TPUModelRunner(model_config, parallel_config, + scheduler_config, device_config, + cache_config, load_config, + vision_language_config) + + def init_device(self) -> None: + os.environ["PJRT_DEVICE"] = "TPU" + self.device = xm.xla_device() + self.device_config.device = self.device + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE(woosuk): This is just a hack to initialize the TP group. + # This cannot perform the actual communication ops. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Set random seed. 
+ set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): This does not completely eliminate the recompilation + # overhead because dynamo does not cache the compiled results. + xr.initialize_cache(os.path.expanduser(envs.VLLM_XLA_CACHE_PATH), + readonly=False) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_layers = self.model_config.get_num_layers(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + + kv_caches = [(None, None) for _ in range(num_layers)] + self.model_runner._dummy_run( + batch_size=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + kv_caches=kv_caches, + is_prompt=True, + ) + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + m = xm.get_memory_info(self.device) + program_size = 1024 * 1024 * 1024 # 1GB + free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0) + kv_cache_bytes = int(free_bytes * + self.cache_config.gpu_memory_utilization) + kv_cache_dtype_bytes = get_dtype_size(self.cache_dtype) + block_size = self.cache_config.block_size + num_tpu_blocks = (kv_cache_bytes // + (kv_cache_dtype_bytes * block_size * num_layers * 2 * + head_size * num_kv_heads)) + num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to a multiple of 8. + return num_tpu_blocks, 0 + + def initialize_cache( + self, + num_gpu_blocks: int, + num_cpu_blocks: int, + ) -> None: + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self.block_size = self.cache_config.block_size + + dtype = self.cache_dtype + num_layers = self.model_config.get_num_layers(self.parallel_config) + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + + self.tpu_cache = [] + tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_gpu_blocks, self.block_size, num_kv_heads, head_size) + for _ in range(num_layers): + key_cache = torch.zeros(tpu_cache_shape, + dtype=dtype, + device=self.device) + value_cache = torch.zeros_like(key_cache) + self.tpu_cache.append((key_cache, value_cache)) + self._warmup_model() + + def _warmup_model(self) -> None: + # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined + # for CUDA graphs. We should refactor this part. + if not self.model_config.enforce_eager: + # Warm up the model with all possible input shapes so that + # compilation never happens during the actual execution. + # This may take ~30 mins for the first run and ~20 mins for the + # subsequent runs. + # If `enforce_eager` is True, the ahead-of-time compilation is + # skipped and the compilation happens during the actual execution, + # which is bad for performance but useful for development.
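            # Aside (illustration only, not part of the diff): with
            # hypothetical engine limits of max_model_len=2048,
            # max_num_batched_tokens=2048 and max_num_seqs=256, warmup_model()
            # (see TPUModelRunner above) compiles roughly
            #     prefill: batch_size=1, seq_len in 16, 32, 64, ..., 2048
            #     decode : seq_len=1, batch_size in 1, 2, 4, 8, 16, 32, 48, ..., 256
            # The exact counts behind the "10-15 graphs for prefill and 30-40
            # graphs for decode" note in init_device() depend on these limits.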
+ self.model_runner.warmup_model(self.tpu_cache) + + def get_cache_block_size_bytes(self) -> int: + head_size = self.model_config.get_head_size() + num_heads = self.model_config.get_num_kv_heads(self.parallel_config) + num_layers = self.model_config.get_num_layers(self.parallel_config) + + key_cache_block = self.cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(self.cache_dtype) + return dtype_size * total + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + if execute_model_req is None: + return [] + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + num_seq_groups = len(seq_group_metadata_list) + if num_seq_groups == 0: + return [] + + # Currently, TPUWorker does not support swapping. + # TODO(woosuk): Support block copying. + assert len(execute_model_req.blocks_to_swap_in) == 0, ( + "Swapping is not supported for the TPU backend.") + assert len(execute_model_req.blocks_to_swap_out) == 0, ( + "Swapping is not supported for the TPU backend.") + assert len(execute_model_req.blocks_to_copy) == 0 + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.tpu_cache) + return [output]
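As a sanity check on the KV-cache sizing in TPUWorker.determine_num_available_blocks() above, the per-block byte count and the final block count work out as in the sketch below. All concrete numbers here (a bfloat16 cache, block_size=16, 32 layers, 8 KV heads, head_size=128, and roughly 7.2 GiB of HBM left for the cache after profiling) are hypothetical and not taken from this diff.

    def num_tpu_blocks(kv_cache_bytes: int,
                       dtype_size: int = 2,   # bfloat16
                       block_size: int = 16,
                       num_layers: int = 32,
                       num_kv_heads: int = 8,
                       head_size: int = 128) -> int:
        # Bytes that one block occupies across all layers, for keys and values.
        block_bytes = (dtype_size * block_size * num_layers * 2 *
                       head_size * num_kv_heads)  # 2 MiB with these values
        num_blocks = kv_cache_bytes // block_bytes
        return (num_blocks // 8) * 8  # round down to a multiple of 8


    # ~7.2 GiB left for the cache -> 3680 blocks -> 3680 * 16 = 58,880 tokens.
    print(num_tpu_blocks(int(7.2 * 1024**3)))  # 3680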