diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b1d4461d164a..d9ded25ee916 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -61,12 +61,14 @@ def run_test(more_args): ) measured_value = results["results"][TASK][FILTER] + print(f"{measured_value=}") assert (measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not current_platform.is_cuda(), +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), reason="V1 currently only supported on CUDA") def test_lm_eval_accuracy_v1_engine(monkeypatch): """Run with the V1 Engine.""" diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 664707e9dc65..9919c31ef8ab 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -25,6 +25,7 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() + PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() NO_ATTENTION = enum.auto() @@ -140,6 +141,10 @@ def _cached_get_attn_backend( from vllm.v1.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend as FlashAttentionBackendV1) return FlashAttentionBackendV1 + if backend == _Backend.PALLAS_VLLM_V1: + from vllm.v1.attention.backends.pallas import ( # noqa: F401 + PallasAttentionBackend as PallasAttentionBackendV1) + return PallasAttentionBackendV1 if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -232,8 +237,11 @@ def which_attn_to_use(head_size: int, return _Backend.IPEX if current_platform.is_tpu(): - if selected_backend != _Backend.PALLAS: + if (selected_backend != _Backend.PALLAS + and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) + if use_v1: + return _Backend.PALLAS_VLLM_V1 return _Backend.PALLAS if current_platform.is_rocm(): diff --git a/vllm/config.py b/vllm/config.py index b354fb61d7b7..6760bcbc24c0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1225,6 +1225,8 @@ def __init__(self, device: str = "auto") -> None: # Some device types require processing inputs on CPU if self.device_type in ["neuron", "openvino"]: self.device = torch.device("cpu") + # Device initialization should happen after initializing the + # distributed runtime. 
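# For TPU, the device is therefore left as None at this point:
# TPUWorker.initialize() (added in this patch) obtains it via xm.xla_device()
# only after init_distributed_environment() has run, and writes it back into
# device_config.device.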
elif self.device_type in ["tpu"]: self.device = None else: diff --git a/vllm/utils.py b/vllm/utils.py index 1b02cbff79f7..4a548496a478 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -728,6 +728,9 @@ def is_pin_memory_available() -> bool: elif current_platform.is_hpu(): print_warning_once("Pin memory is not supported on HPU.") return False + elif current_platform.is_tpu(): + print_warning_once("Pin memory is not supported on TPU.") + return False elif current_platform.is_cpu() or current_platform.is_openvino(): return False return True diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e73a1e60b273..057f281c220a 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -8,7 +8,6 @@ AttentionMetadata, AttentionType) from vllm.forward_context import get_forward_context from vllm.utils import direct_register_custom_op -from vllm.vllm_flash_attn import flash_attn_varlen_func class FlashAttentionBackend(AttentionBackend): @@ -202,6 +201,8 @@ def unified_v1_flash_attention( v_scale, ) + from vllm.vllm_flash_attn import flash_attn_varlen_func + attn_output = flash_attn_varlen_func( q=query[:num_actual_tokens], k=key_cache, diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py new file mode 100644 index 000000000000..b2cdc06ee78c --- /dev/null +++ b/vllm/v1/attention/backends/pallas.py @@ -0,0 +1,298 @@ +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionType) + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [128] + + @staticmethod + def get_name() -> str: + return "pallas-vllm-v1" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionImpl"]: + return PallasAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasAttentionMetadata"]: + return PallasAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (num_kv_heads, num_blocks, block_size, head_size) + + +@dataclass +class PallasAttentionMetadata: + + is_prompt: bool + slot_mapping: torch.Tensor + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + + +class PallasAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise 
NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = PallasAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PallasAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in PallasAttentionImpl.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionImpl") + + # Unpack + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + # Write to KV cache. + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache = kv_cache[0] + value_cache = kv_cache[1] + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.is_prompt: + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. 
+ if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention requires [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. 
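# Worked example of the chunking above (hypothetical sizes): with
# max_num_blocks_per_req = 512 int32 block-table entries per sequence,
# size_per_seq = 4 * 512 = 2048 bytes and
# max_num_seq = 512 * 1024 // 2048 = 256, so a batch of 300 decode sequences
# would be split into two chunks of at most 256 sequences each.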
+ return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." + if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + ) + return output diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ee860e792281..98eeb07ba5cf 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -22,6 +22,10 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], ) -> None: + # TODO: properly handle for TPU. + cache_config.enable_prefix_caching = False + scheduler_config.chunked_prefill_enabled = False + self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config @@ -147,6 +151,12 @@ def schedule(self) -> "SchedulerOutput": num_computed_tokens -= 1 num_new_tokens = 1 computed_blocks.pop() + + # If chunked prefill is not enabled, breakout of the loop. 
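# In that case the loop over waiting requests stops for this step and the
# request is retried on a later scheduling step, once the whole prompt fits
# within the remaining token budget.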
+ if (not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget): + break + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 new_blocks = self.kv_cache_manager.allocate_slots( diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 2d7c58cfea13..080fa46c6d54 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -20,6 +21,7 @@ from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.tpu_executor import TPUExecutor logger = init_logger(__name__) @@ -29,7 +31,7 @@ class AsyncLLM(EngineClient): def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Union[GPUExecutor, TPUExecutor]], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -120,6 +122,8 @@ def shutdown(self): @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): + if current_platform.is_tpu: + return TPUExecutor return GPUExecutor async def add_request( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f9d3473d0131..95a910b3e1c4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -18,6 +18,7 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.tpu_executor import TPUExecutor from vllm.v1.request import Request, RequestStatus from vllm.version import __version__ as VLLM_VERSION @@ -34,17 +35,19 @@ class EngineCore: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Union[GPUExecutor, TPUExecutor]], usage_context: UsageContext, ): # Override the configs for V1. # FIXME if usage_context == UsageContext.LLM_CLASS: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 8192 + # vllm_config.scheduler_config.max_num_seqs = 1024 + # vllm_config.scheduler_config.max_num_batched_tokens = 8192 + pass elif usage_context == UsageContext.OPENAI_API_SERVER: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 2048 + # vllm_config.scheduler_config.max_num_seqs = 1024 + # vllm_config.scheduler_config.max_num_batched_tokens = 2048 + pass # TODO (ywang96): Enable APC by default when VLM supports it. 
if not vllm_config.model_config.is_multimodal_model: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index f37db92e8ea6..4d419b8f97bf 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -17,6 +18,7 @@ from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.tpu_executor import TPUExecutor logger = init_logger(__name__) @@ -27,7 +29,7 @@ class LLMEngine: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Union[GPUExecutor, TPUExecutor]], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -92,6 +94,8 @@ def from_engine_args( @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): + if current_platform.is_tpu(): + return TPUExecutor return GPUExecutor def stop_remote_worker_execution_loop(self) -> None: diff --git a/vllm/v1/executor/tpu_executor.py b/vllm/v1/executor/tpu_executor.py new file mode 100644 index 000000000000..5e6e63086946 --- /dev/null +++ b/vllm/v1/executor/tpu_executor.py @@ -0,0 +1,80 @@ +from typing import Optional, Tuple + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_worker import TPUWorker + +logger = init_logger(__name__) + +# import torch_xla.debug.profiler as xp + + +class TPUExecutor: + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.worker = self._create_worker() + self.worker.initialize() + self.worker.load_model() + + # self.server = xp.start_server(9012) + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> TPUWorker: + """Return worker init args for a given rank.""" + + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + return TPUWorker( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.worker.determine_num_available_blocks() + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. 
+ """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. + logger.info("# TPU blocks: %d", num_tpu_blocks) + self.worker.initialize_cache(num_tpu_blocks) + self.worker.compile_or_warm_up_model() + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + # xp.trace_detached('localhost:9012', "./profiles") + output = self.worker.execute_model(scheduler_output) + return output + + def check_health(self) -> None: + # TPUExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py new file mode 100644 index 000000000000..7963fe4973b5 --- /dev/null +++ b/vllm/v1/worker/tpu_model_runner.py @@ -0,0 +1,981 @@ +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +import torch_xla.core.xla_model as xm + +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import MultiModalDataDict +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available +from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, + PallasAttentionMetadata) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME: Find a more reliable way to prevent possible bugs. 
+_PAD_SLOT_ID = 1_000_000_000 + + +@dataclass +class PrefillInputData: + + request_ids: List + prompt_lens: List + token_ids: List + position_ids: List + attn_metadata: List + + def zipped(self): + return zip(self.request_ids, self.prompt_lens, self.token_ids, + self.position_ids, self.attn_metadata) + + +@dataclass +class DecodeInputData: + + num_decodes: int + token_ids: Optional[torch.Tensor] = None + position_ids: Optional[torch.Tensor] = None + attn_metadata: PallasAttentionMetadata = None + + +class TPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + ): + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + cache_config.cache_dtype] + + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) + self.max_num_tokens = scheduler_config.max_num_batched_tokens + + # Model-related. + self.num_attn_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() + + # List[k_cache, v_cache] + self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + + # Request states. + self.requests: Dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.scheduler_config.max_num_seqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + device=self.device, + pin_memory=self.pin_memory, + ) + + self.prefill_positions = torch.tensor( + range(self.max_model_len), + device="cpu", + ).to(torch.int32).reshape(1, -1) + + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + # Remove stopped requests from the cached states. + # Keep the states of the pre-empted requests. + for req_id in scheduler_output.finished_req_ids: + self.requests.pop(req_id, None) + + # Remove the requests from the persistent batch. + stopped_req_ids = set().union( + scheduler_output.preempted_req_ids, + scheduler_output.finished_req_ids, + ) + removed_req_indices: List[int] = [] + for req_id in stopped_req_ids: + req_index = self.input_batch.remove_request(req_id) + if req_index is not None: + removed_req_indices.append(req_index) + + # Update the states of the running requests. 
+ for req_data in scheduler_output.scheduled_running_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + req_index = self.input_batch.req_id_to_index[req_id] + + # Update the num_computed_tokens. + req_state.num_computed_tokens = req_data.num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + req_data.num_computed_tokens) + + # Update the block table. + num_new_blocks = len(req_data.new_block_ids) + if num_new_blocks == 0: + continue + start_index = len(req_state.block_ids) + end_index = start_index + num_new_blocks + req_state.block_ids.extend(req_data.new_block_ids) + self.input_batch.block_table_cpu[ + req_index, start_index:end_index] = req_data.new_block_ids + + req_ids_to_add: List[str] = [] + # Add new requests to the cached states. + for req_data in scheduler_output.scheduled_new_reqs: + req_id = req_data.req_id + sampling_params = req_data.sampling_params + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + + self.requests[req_id] = CachedRequestState( + req_id=req_id, + prompt_token_ids=req_data.prompt_token_ids, + prompt=req_data.prompt, + multi_modal_data=req_data.multi_modal_data, + sampling_params=sampling_params, + generator=generator, + block_ids=req_data.block_ids, + num_computed_tokens=req_data.num_computed_tokens, + output_token_ids=[], + ) + req_ids_to_add.append(req_id) + + # Update the cached states of the resumed requests. + for req_data in scheduler_output.scheduled_resumed_reqs: + req_id = req_data.req_id + req_state = self.requests[req_id] + + req_state.block_ids = req_data.block_ids + req_state.num_computed_tokens = req_data.num_computed_tokens + req_ids_to_add.append(req_id) + + # THIS MOVES ALL THE DECODES TO THE FIRST N IN BATCH. + # Condense the batched states if there are empty indices. + removed_req_indices = sorted(removed_req_indices, reverse=True) + if removed_req_indices: + self.input_batch.condense(removed_req_indices) + + # ALL THE PREFILLS ARE THE LAST M IN THE BATCH. + # These are added at the end after the bacth is condensed. + self.input_batch.num_prefills = len(req_ids_to_add) + for req_id in req_ids_to_add: + req_state = self.requests[req_id] + self.input_batch.add_request(req_state, None) + + def _prepare_prefill_inputs( + self, + num_scheduled_tokens: List[int], + ) -> PrefillInputData: + # Each prefill run separately with shape [1, padded_prompt_len]. + # So we create lists that will be used in execute_model(). + + prefill_request_ids = [] + prefill_prompt_lens = [] + prefill_token_ids = [] + prefill_position_ids = [] + prefill_attn_metadata = [] + + # DECODES are the first num_decodes REQUESTS. + # PREFILLS are the next num_reqs - num_decodes REQUESTS. + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + for idx in range(num_decodes, num_reqs): + prefill_request_ids.append(self.input_batch.req_ids[idx]) + + # STATIC SHAPE: prefills are padded to the next power of 2. + prompt_len = num_scheduled_tokens[idx] + padded_prompt_len = _get_padded_prefill_len(prompt_len) + prefill_prompt_lens.append(prompt_len) + assert padded_prompt_len <= self.max_model_len + + # TOKEN_IDS. + token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[ + idx, :padded_prompt_len].reshape(1, -1)) + prefill_token_ids.append(token_ids.to(self.device)) + + # POSITIONS. 
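# The position ids are simply 0..padded_prompt_len-1, taken as a slice of
# the precomputed self.prefill_positions row instead of building a new
# tensor for every prefill.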
+ positions = self.prefill_positions[:, :padded_prompt_len] + prefill_position_ids.append(positions.to(self.device)) + + # SLOT_MAPPING. + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_numbers = self.input_batch.block_table_cpu_tensor[ + idx, positions // self.block_size].reshape(1, -1) + block_offsets = positions % self.block_size + slot_mapping = block_numbers * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[:, prompt_len:] = _PAD_SLOT_ID + slot_mapping = slot_mapping.long() + + # ATTN_METADATA. + prefill_attn_metadata.append( + PallasAttentionMetadata( + is_prompt=True, + slot_mapping=slot_mapping.to(self.device), + block_tables=None, + context_lens=None, + )) + + return PrefillInputData( + request_ids=prefill_request_ids, + prompt_lens=prefill_prompt_lens, + token_ids=prefill_token_ids, + position_ids=prefill_position_ids, + attn_metadata=prefill_attn_metadata, + ) + + def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData: + # Decodes run as one single padded batch with shape [batch, 1] + # + # We need to set _PAD_SLOT_ID for the padding tokens in the + # slot_mapping, such that the attention KV cache insertion + # logic knows to ignore those indicies. Otherwise, the + # padding data can be dummy since we have a causal mask. + + if num_decodes == 0: + return DecodeInputData(num_decodes=0) + + # PAD FOR STATIC SHAPES. + padded_batch_size = _get_padded_batch_size(num_decodes) + + # POSITIONS. [batch, 1] + # We slice at the end, since we use the positions for gathering. + positions = torch.from_numpy( + self.input_batch.num_computed_tokens_cpu.reshape(-1, 1)) + index = positions.to(torch.int64) + positions = positions[:padded_batch_size] + + # TOKEN_IDS. [batch, 1] + token_ids = torch.gather( + input=torch.from_numpy(self.input_batch.token_ids_cpu), + dim=1, + index=index, + )[:padded_batch_size] + + # SLOT_MAPPING [batch, 1] + # The "slot" is the "physical index" of a token in the KV cache. + # Look up the block_idx in the block table (logical<>physical map) + # to compute this. + block_number = torch.gather( + input=self.input_batch.block_table_cpu_tensor, + dim=1, + index=(index // self.block_size)) + block_offsets = index % self.block_size + slot_mapping = block_number * self.block_size + block_offsets + # Set an out of range value for the padding tokens so that they + # are ignored when inserting into the KV cache. + slot_mapping[num_decodes:] = _PAD_SLOT_ID + slot_mapping = slot_mapping[:padded_batch_size] + + # BLOCK_TABLE [batch, max_num_blocks_per_req] + block_table = self.input_batch.block_table_cpu_tensor[: + padded_batch_size] + + # CONTEXT_LENS [batch_size] + context_lens = (positions.reshape(-1) + 1) + + # CPU<>TPU sync happens here. 
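# Everything above is assembled on CPU; the .to(self.device) calls in the
# return below are what actually move the decode inputs to the TPU.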
+ return DecodeInputData(num_decodes=num_decodes, + token_ids=token_ids.to(self.device), + position_ids=positions.to(self.device), + attn_metadata=PallasAttentionMetadata( + is_prompt=False, + slot_mapping=slot_mapping.to(self.device), + block_tables=block_table.to(self.device), + context_lens=context_lens.to(self.device), + )) + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> Tuple[PrefillInputData, Optional[DecodeInputData]]: + + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + + num_reqs = self.input_batch.num_reqs + num_decodes = self.input_batch.num_decodes + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = [] + for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens.append(num_tokens) + + # NOTE: assert that all the decodes are "decodes". + if idx < num_decodes: + assert num_tokens == 1 + + return ( + self._prepare_prefill_inputs(num_scheduled_tokens), + self._prepare_decode_inputs(num_decodes), + ) + + def _prepare_sampling( + self, + scheduler_output: "SchedulerOutput", + ) -> SamplingMetadata: + skip_copy = True + if (scheduler_output.finished_req_ids + or scheduler_output.preempted_req_ids): + skip_copy = False + if (scheduler_output.scheduled_new_reqs + or scheduler_output.scheduled_resumed_reqs): + skip_copy = False + # Create the sampling metadata. + sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + return sampling_metadata + + @torch.no_grad() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self._update_states(scheduler_output) + prefill_data, decode_data = self._prepare_inputs(scheduler_output) + num_reqs = self.input_batch.num_reqs + sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32) + + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_batch, 1] + if decode_data.num_decodes > 0: + + # FORWARD. + selected_token_ids = self.model(decode_data.token_ids, + decode_data.position_ids, + decode_data.attn_metadata, + self.kv_caches, + is_prompt=False) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_ids = selected_token_ids.cpu()[:decode_data.num_decodes] + sampled_token_ids_list = token_ids.tolist() + sampled_token_ids[:decode_data.num_decodes] = token_ids + + # UPDATE REQUEST STATE. + for i, req_id in enumerate( + self.input_batch.req_ids[:decode_data.num_decodes]): + req_state = self.requests[req_id] + + # TODO: ASSERT NO CHUNKED PREFILL. + assert scheduler_output.num_scheduled_tokens[req_id] == 1 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len == req_state.num_tokens + + token_id = sampled_token_ids_list[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + ######################### PREFILLS ######################### + # Prefills run separately with shape [1, padded_prefill_len], + # due to lack of variable length attention kernel so far. + for idx, (req_id, prompt_len, token_ids, position_ids, + attn_metadata) in enumerate(prefill_data.zipped()): + + # FORWARD. 
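# Each prefill is run as its own [1, padded_prompt_len] forward; since
# prompt lengths are padded to powers of two (_get_padded_prefill_len), the
# number of distinct XLA graphs compiled for prefill stays bounded.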
+ selected_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + self.kv_caches, + is_prompt=True) + + # NOTE: TPU<>CPU sync happens here. + # We need to call .cpu() first to avoid recompilation. + token_id = selected_token_ids.cpu()[prompt_len - 1].item() + sampled_token_ids[decode_data.num_decodes + idx] = token_id + req_state = self.requests[req_id] + + # TODO: ASSERT NO PREFIX CACHING. + assert req_state.num_computed_tokens == 0 + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + + # TODO: ASSERT NO CHUNKED PREFILL. + assert seq_len == req_state.num_tokens + assert prompt_len == seq_len + + # UPDATE REQUEST STATE. + req_idx = self.input_batch.req_id_to_index[req_id] + self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id + req_state.output_token_ids.append(token_id) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids[:num_reqs], + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids_cpu=sampled_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, + ) + + def load_model(self) -> None: + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + + # xm_tp_rank = xr.global_ordinal() + # with patch( + # "vllm.model_executor.layers.vocab_parallel_embedding." + # "get_tensor_model_parallel_rank", + # return_value=xm_tp_rank): + # model = get_model(vllm_config=self.vllm_config) + model = get_model(vllm_config=self.vllm_config) + model = model.eval() + xm.wait_device_ops() + self.model = ModelWrapper(model) + + def _dummy_run(self, batch_size: int, seq_len: int, + kv_caches: List[torch.Tensor], is_prompt: bool) -> None: + """Dummy warmup run for memory usage and graph compilation.""" + + input_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = None if is_prompt else torch.zeros( + (batch_size, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device, + ) + context_lens = None if is_prompt else torch.ones( + (batch_size, ), + dtype=torch.int32, + device=self.device, + ) + attn_metadata = PallasAttentionMetadata( + is_prompt=is_prompt, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + + # NOTE: There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). 
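# The mark_dynamic calls below tell Dynamo which dimension varies across
# runs (sequence length for prefill, batch size for decode), so one FX graph
# is traced per mode while XLA still specializes on each concrete padded
# shape.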
+ if is_prompt: + torch._dynamo.mark_dynamic(input_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + + # Dummy run. + self.model(input_ids, + position_ids, + attn_metadata, + kv_caches, + is_prompt=is_prompt) + + def profile_run(self) -> None: + """Profile to measure peak memory during forward pass.""" + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + dummy_kv_caches = [( + torch.tensor([], dtype=torch.float32, device=self.device), + torch.tensor([], dtype=torch.float32, device=self.device), + ) for _ in range(self.num_attn_layers)] + + # Round to multiple of 16. + seq_len = (self.max_num_tokens + 15) // 16 * 16 + + # Run empty forward. + self._dummy_run(batch_size=1, + seq_len=seq_len, + kv_caches=dummy_kv_caches, + is_prompt=True) + + def capture_model(self) -> None: + """Compile the model.""" + + logger.info("Compiling the model with different input shapes.") + + # Prefill shapes. + start = time.perf_counter() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=True) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + if seq_len >= self.model_config.max_model_len: + break + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.perf_counter() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Decode shapes. 
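# With the schedule below (start at 8, double until 16, then step by 16),
# the warmed-up decode batch sizes are 8, 16, 32, 48, 64, ... up to
# max_num_seqs, each with seq_len = 1.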
+ start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, + seq_len, + self.kv_caches, + is_prompt=False) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def initialize_kv_cache(self, num_blocks: int) -> None: + assert len(self.kv_caches) == 0 + kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + for _ in range(self.num_attn_layers): + self.kv_caches.append(( + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + torch.zeros(kv_cache_shape, + dtype=self.kv_cache_dtype, + device=self.device), + )) + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalDataDict"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.zeros((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.zeros(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. 
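# Each sampling parameter below is kept in three forms: a device tensor used
# when building SamplingMetadata, a CPU staging tensor (pinned when pin
# memory is available), and a numpy view of the CPU tensor for cheap
# per-request scalar writes in add_request(); the CPU-to-device copy only
# happens in make_sampling_metadata().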
+ self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + self.num_prefills = 0 + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # 
The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def num_decodes(self) -> int: + return self.num_reqs - self.num_prefills + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 + + +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): + + def __init__(self, model: nn.Module): + self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to 
the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: PallasAttentionMetadata, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, None) + + # Greedy sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + return argmax_token_ids.squeeze(dim=1) + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. + if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. 
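# In practice the code below rounds up to the next power of two rather than
# just the next multiple of 16 (e.g. 17 -> 32, 100 -> 128, 16 -> 16); every
# result is still a multiple of 16, and the small set of distinct padded
# lengths bounds the number of compiled prefill shapes.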
+ if x <= 16: + return 16 + return 1 << (x - 1).bit_length() diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py new file mode 100644 index 000000000000..866c1dbf6ea9 --- /dev/null +++ b/vllm/v1/worker/tpu_worker.py @@ -0,0 +1,198 @@ +"""A TPU worker class.""" + +import os +from typing import TYPE_CHECKING, Tuple + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.tpu_model_runner import TPUModelRunner + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + + +class TPUWorker: + + def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int, + distributed_init_method: str): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + def initialize(self): + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE: This is just to initialize the TP group and broadcast + # the input objects on CPU. The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Init ModelRunner here, so that we have access to self.device. + self.model_runner = TPUModelRunner(self.vllm_config) + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. 
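# For example, with world_size=4 and rank 2 the compiled graphs end up under
# {VLLM_XLA_CACHE_PATH}/tp4_rank2.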
+ world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + + self.model_runner.profile_run() + + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + m = xm.get_memory_info(self.device) + total_tpu_memory = m["bytes_limit"] + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024) + logger.debug("Total Memory: %sGB", + total_tpu_memory // 1024 // 1024 // 1024) + + cache_block_size = _get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + num_tpu_blocks = int( + (total_tpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_tpu_blocks = (max(num_tpu_blocks, 0) // 8) * 8 + return num_tpu_blocks, 0 + + def initialize_cache(self, num_tpu_blocks: int) -> None: + """Allocate TPU and CPU KV cache with the specified number of blocks.""" + + if num_tpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_tpu_blocks + max_model_len = self.model_config.max_model_len + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.model_runner.initialize_kv_cache(num_tpu_blocks) + + # Get the maximum amount of memory used by the model weights and + # intermediate activations. + xm.mark_step() + xm.wait_device_ops() + m = xm.get_memory_info(self.device) + peak_memory = m[ + "peak_bytes_used"] # Weights + intermediate activations. + logger.debug("Peak GB Used Post KV Cache: %sGB", + peak_memory // 1024 // 1024 // 1024) + + def compile_or_warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model() + + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(scheduler_output) + # TODO(woosuk): Send the output to the engine process. + return output + + +# TODO: this is a duplicate. 
+def _get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, +) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total
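# A minimal sketch (not part of the patch) of the KV-cache sizing math that
# _get_cache_block_size() and TPUWorker.determine_num_available_blocks()
# implement, using purely hypothetical numbers (8 KV heads, head_size 128,
# 32 attention layers, bfloat16, block_size 16, a 16 GiB TPU at 90%
# utilization with 6 GiB of peak non-cache usage):

def sketch_num_tpu_blocks() -> int:
    block_size = 16                 # tokens per KV-cache block
    num_kv_heads = 8
    head_size = 128
    num_attention_layers = 32
    dtype_size = 2                  # bytes per element for bfloat16

    # Bytes per block: one key block plus one value block, for every layer.
    key_cache_block = block_size * num_kv_heads * head_size
    value_cache_block = key_cache_block
    cache_block_size = dtype_size * num_attention_layers * (
        key_cache_block + value_cache_block)        # 2 MiB per block here

    total_tpu_memory = 16 * 1024**3                 # hypothetical bytes_limit
    peak_memory = 6 * 1024**3                       # hypothetical peak usage
    gpu_memory_utilization = 0.9

    num_tpu_blocks = int(
        (total_tpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)
    # Round down to a multiple of 8, as the worker does.
    return (max(num_tpu_blocks, 0) // 8) * 8


# With these numbers cache_block_size = 2 * 32 * 2 * (16 * 8 * 128) bytes
# = 2 MiB, the cache budget is 0.9 * 16 GiB - 6 GiB ~= 8.4 GiB, and the
# result is roughly 4300 blocks (~68K cacheable tokens at block_size 16)
# before the rounding to a multiple of 8.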