diff --git a/vllm/envs.py b/vllm/envs.py index 9d585bf3578e..f99c7de79183 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -81,6 +81,7 @@ VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True + VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -581,6 +582,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_ROCM_USE_AITER_MLA": lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1")), + + # Whether to use aiter mha ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MHA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in + ("true", "1")), + # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index e8abd32ff6ba..d429b2eb3b3c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: @@ -42,46 +43,71 @@ def fused_add_rms_norm( return x, residual -def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +if is_rocm_aiter_rmsnorm_enabled(): - import aiter as rocm_aiter - if x.dim() > 2: - x_original_shape = x.shape - x = x.reshape(-1, x_original_shape[-1]) - x = rocm_aiter.rms_norm(x, weight, variance_epsilon) - return x.reshape(x_original_shape) + def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: - return rocm_aiter.rms_norm(x, weight, variance_epsilon) + import aiter as rocm_aiter + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rocm_aiter.rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + return rocm_aiter.rms_norm(x, weight, variance_epsilon) -def rocm_aiter_fused_add_rms_norm( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + def rocm_aiter_rms_norm_fake(input: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + return input.clone() - import aiter as rocm_aiter + direct_register_custom_op( + op_name="rocm_aiter_rms_norm", + op_func=rocm_aiter_rms_norm_impl, + mutates_args=[], + fake_impl=rocm_aiter_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) - residual_out = torch.empty_like(residual) - output = torch.empty_like(x) - rocm_aiter.rmsnorm2d_fwd_with_add( - output, # output - x, # input - residual, # residual input - residual_out, # residual output - weight, - variance_epsilon, + def rocm_aiter_fused_add_rms_norm_impl( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + + import aiter as rocm_aiter + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) + rocm_aiter.rmsnorm2d_fwd_with_add( + output, # output + x, # input + residual, # residual input + residual_out, # residual output + weight, + variance_epsilon, + ) + return output, residual_out + + def rocm_aiter_fused_add_rms_norm_fake( + x: torch.Tensor, residual: torch.Tensor, weight: 
torch.Tensor, + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + return x.clone(), residual.clone() + + direct_register_custom_op( + op_name="rocm_aiter_fused_add_rms_norm", + op_func=rocm_aiter_fused_add_rms_norm_impl, + mutates_args=[], + fake_impl=rocm_aiter_fused_add_rms_norm_fake, + dispatch_key=current_platform.dispatch_key, ) - return output, residual_out def dispatch_cuda_rmsnorm_func(add_residual: bool): if add_residual: if is_rocm_aiter_rmsnorm_enabled(): - return rocm_aiter_fused_add_rms_norm + return torch.ops.vllm.rocm_aiter_fused_add_rms_norm return fused_add_rms_norm if is_rocm_aiter_rmsnorm_enabled(): - return rocm_aiter_rms_norm + return torch.ops.vllm.rocm_aiter_rms_norm return rms_norm diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f3d64f01b0f7..70c2ea27765b 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -181,9 +181,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if envs.VLLM_USE_V1: - logger.info("Using Triton Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "triton_attn.TritonAttentionBackend") + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ + and on_mi250_mi300(): + logger.info("Using Flash Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "rocm_aiter_fa.AiterFlashAttentionBackend") + else: + logger.info("Using Triton Attention backend on V1 engine.") + return ("vllm.v1.attention.backends." + "triton_attn.TritonAttentionBackend") if selected_backend == _Backend.ROCM_FLASH: if not cls.has_device_capability(90): # not Instinct series GPUs. diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py new file mode 100644 index 000000000000..a84bd984f655 --- /dev/null +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -0,0 +1,633 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.v1.attention.backends.flash_attn import ( + make_local_attention_virtual_batches) +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +if current_platform.is_rocm(): + import aiter + + from vllm.triton_utils import tl, triton + from vllm.utils import direct_register_custom_op + + @triton.jit + def _vllm_layout_trans_kernel( + k_buffer_ptr, + v_buffer_ptr, + k_values_ptr, + v_values_ptr, + b_seq_lens_loc, + block_table, + block_table_stride_0, + X: tl.constexpr, + H_KV: tl.constexpr, + D: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + + tl.arange(0, 2)) + batch_token_start, batch_token_end = 
tl.split(batch_token_indexes) + seq_len = batch_token_end - batch_token_start + + DIM0: tl.constexpr = H_KV * D // X + DIM1: tl.constexpr = X * BLOCK_SIZE + E_DIM: tl.constexpr = H_KV * D + if block_idx * BLOCK_SIZE < seq_len: + # print("block_idx", block_idx) + k_block_mask = (block_idx * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE)[None, :, None]) < seq_len + v_block_mask = (block_idx * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE)[None, :]) < seq_len + + kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 + + block_idx) + + k_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange( + 0, DIM0)[:, None, None] * DIM1 + tl.arange( + 0, BLOCK_SIZE)[None, :, None] * X + tl.arange( + 0, X)[None, None, :] + v_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange( + 0, E_DIM)[:, None] * BLOCK_SIZE + tl.arange( + 0, BLOCK_SIZE)[None, :] + k_vals = tl.load(k_buffer_ptr + k_buffer_off, + mask=k_block_mask, + other=0.0) + v_vals = tl.load(v_buffer_ptr + v_buffer_off, + mask=v_block_mask, + other=0.0) + k_vals = k_vals.trans(0, 2, 1).view(E_DIM, BLOCK_SIZE) + block_mask = (block_idx * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len + + kv_values_off = batch_token_start * E_DIM + \ + block_idx * BLOCK_SIZE * E_DIM + tl.arange( + 0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :] + tl.store(k_values_ptr + kv_values_off, k_vals.T, mask=block_mask) + tl.store(v_values_ptr + kv_values_off, v_vals.T, mask=block_mask) + + def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer, + max_seqlen, total_tokens): + H_KV = v_buffer.shape[1] + D = v_buffer.shape[2] + BLOCK_SIZE = v_buffer.shape[3] + X = k_buffer.shape[-1] + dtype = k_buffer.dtype + + k_values = torch.empty((total_tokens, H_KV, D), + dtype=dtype, + device="cuda") + v_values = torch.empty((total_tokens, H_KV, D), + dtype=dtype, + device="cuda") + + grid = (block_table.shape[0], + (max_seqlen + BLOCK_SIZE - 1) // BLOCK_SIZE) + + _vllm_layout_trans_kernel[grid]( + k_buffer, + v_buffer, + k_values, + v_values, + b_seq_lens_loc, + block_table, + block_table.stride(0), + X=X, + H_KV=H_KV, + D=D, + BLOCK_SIZE=BLOCK_SIZE, + num_stages=1, + num_warps=4, + ) + + return k_values, v_values + + def flash_attn_varlen_func_impl( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + out: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_tokens: int, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + window_size: Optional[list[int]], # -1 means infinite context window + alibi_slopes: Optional[list[float]], + block_table: torch.Tensor, + ) -> torch.Tensor: + k, v = vllm_layout_trans(cu_seqlens_k, block_table, k_cache, v_cache, + max_seqlen_k, total_tokens) + output = aiter.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + window_size=window_size, + out=out, + ) + return output + + def flash_attn_varlen_func_fake( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + out: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_tokens: int, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + window_size: Optional[list[int]], # -1 means infinite context window + alibi_slopes: Optional[list[float]], + block_table: torch.Tensor, + ) -> torch.Tensor: + return torch.empty(q.shape[0], + q.shape[1], + v_cache.shape[-2], + 
dtype=torch.float8_e4m3fnuz, + device="cuda") + + direct_register_custom_op("flash_attn_varlen_func", + flash_attn_varlen_func_impl, ["out"], + flash_attn_varlen_func_fake, + dispatch_key=current_platform.dispatch_key) + +logger = init_logger(__name__) + + +class AiterFlashAttentionMetadataBuilder: + + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, + block_table: BlockTable): + model_config = runner.model_config + + self.runner = runner + self.num_heads_q = model_config.get_num_attention_heads( + runner.parallel_config) + self.num_heads_kv = model_config.get_num_kv_heads( + runner.parallel_config) + self.headdim = model_config.get_head_size() + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table + + # Sliding window size to be used with the AOT scheduler will be + # populated on first build() call. + self.aot_sliding_window: Optional[tuple[int, int]] = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): + max_seq_len = self.runner.seq_lens_np[:num_reqs].max() + total_tokens = self.runner.seq_lens_np[:num_reqs].sum() + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] + + cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, + dtype=torch.int32, + device="cuda") + torch.cumsum(seq_lens, + dim=0, + dtype=cu_seq_lens.dtype, + out=cu_seq_lens[1:]) + + def schedule(batch_size, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + return None + + # for local attention + local_attn_metadata = None + if self.runner.attention_chunk_size is not None: + seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ + virt_block_table_tensor = make_local_attention_virtual_batches( + self.runner.attention_chunk_size, + self.runner.query_start_loc_np[:num_reqs + 1], + self.runner.seq_lens_np[:num_reqs], + block_table_tensor, + self.block_size, + ) + local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( + self.runner.device, non_blocking=True) + local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True) + local_max_query_len = seqlens_q_local_np.max() + local_max_seq_len = virt_k_seqlens_np.max() + local_scheduler_metadata = schedule( + batch_size=local_query_start_loc.shape[0] - 1, + cu_query_lens=local_query_start_loc, + max_query_len=local_max_query_len, + seqlens=local_seqused_k, + max_seq_len=local_max_seq_len, + causal=True) + + local_attn_metadata = \ + AiterFlashAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=local_query_start_loc, + local_seqused_k=local_seqused_k, + local_block_table=virt_block_table_tensor, + local_max_query_len=local_max_query_len, + local_max_seq_len=local_max_seq_len, + local_scheduler_metadata=local_scheduler_metadata, + ) + + use_cascade = common_prefix_len > 0 + + if use_cascade: + cu_prefix_query_lens = 
torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device=self.runner.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.runner.device) + suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - + common_prefix_len) + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( + self.runner.device) + prefix_scheduler_metadata = schedule( + batch_size=1, + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False) + scheduler_metadata = schedule(batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - + common_prefix_len, + causal=True) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + scheduler_metadata = schedule(batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=True) + + attn_metadata = AiterFlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + cu_seq_lens=cu_seq_lens, + total_tokens=total_tokens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, + ) + return attn_metadata + + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + + +class AiterFlashAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["AiterFlashAttentionImpl"]: + return AiterFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return AiterFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]: + return AiterFlashAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + +@dataclass +class AiterFlashAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + cu_seq_lens: torch.Tensor + total_tokens: int + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. 
+ use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + + # for local attention + @dataclass + class LocalAttentionMetadata: + local_query_start_loc: torch.Tensor + local_seqused_k: torch.Tensor + local_block_table: torch.Tensor + local_max_query_len: int + local_max_seq_len: int + local_scheduler_metadata: Optional[torch.Tensor] + + local_attn_metadata: Optional[LocalAttentionMetadata] = None + + +class AiterFlashAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + use_irope: bool = False, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = \ + AiterFlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}. " + "Set VLLM_USE_V1=0 to use another attention backend.") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + self.use_irope = use_irope + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashAttention does not support fp8 kv-cache on this device.") + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AiterFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + if attn_metadata is None: + # Profiling run. + return output + + # IMPORTANT! 
+ # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + # Reshape the input keys and values and store them in the cache. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] and + # value[:num_actual_tokens] because the reshape_and_cache_flash op uses + # the slot_mapping's shape to determine the number of actual tokens. + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + # Compute attention and update output up to `num_actual_tokens`. + use_local_attn = \ + (self.use_irope and attn_metadata.local_attn_metadata is not None) + + if not attn_metadata.use_cascade or use_local_attn: + if use_local_attn: + assert attn_metadata.local_attn_metadata is not None + local_metadata = attn_metadata.local_attn_metadata + cu_seqlens_q = local_metadata.local_query_start_loc + seqused_k = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len + block_table = local_metadata.local_block_table + else: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + cu_seq_lens = attn_metadata.cu_seq_lens + total_tokens = attn_metadata.total_tokens + if max_seqlen_q <= 1: + block_size = value_cache.shape[3] + num_query_heads = query.shape[1] + num_kv_heads = key.shape[1] + head_size = query.shape[2] + + _PARTITION_SIZE_ROCM = 256 + max_num_partitions = ( + (max_seqlen_k + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + total_num_seq = query.shape[0] + tmp_output = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions, + head_size), + dtype=query.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale=self.scale, + block_tables=block_table, + seq_lens=seqused_k, + query_start_loc=cu_seqlens_q, + block_size=block_size, + max_seq_len=max_seqlen_k, + alibi_slopes=self.alibi_slopes, + kv_cache_dtype=self.kv_cache_dtype, + 
+                    k_scale=layer._k_scale,
+                    v_scale=layer._v_scale,
+                )
+            else:
+                torch.ops.vllm.flash_attn_varlen_func(
+                    query[:num_actual_tokens],
+                    key_cache,
+                    value_cache,
+                    out=output[:num_actual_tokens],
+                    cu_seqlens_q=cu_seqlens_q,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_k,
+                    total_tokens=total_tokens,
+                    softmax_scale=self.scale,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=list(self.sliding_window),
+                    block_table=block_table,
+                    cu_seqlens_k=cu_seq_lens,
+                )
+            return output
+        else:
+            raise NotImplementedError(
+                "Cascade attention is not implemented for ROCM AITER")
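
Usage sketch (assumptions: a ROCm build of vLLM with the aiter package installed; the model name below is only an example): with this change, the V1 engine on MI250/MI300 GPUs selects the new AiterFlashAttentionBackend when the main AITER switch is on and VLLM_ROCM_USE_AITER_MHA (default "True") is not disabled; otherwise it keeps the Triton attention backend.

    # Minimal sketch of enabling the backend added by this patch.
    # Assumes ROCm + aiter are available; the model name is a placeholder.
    import os
    os.environ["VLLM_USE_V1"] = "1"              # V1 engine path in rocm.py
    os.environ["VLLM_ROCM_USE_AITER"] = "1"      # master AITER switch
    os.environ["VLLM_ROCM_USE_AITER_MHA"] = "1"  # new flag; defaults to True

    from vllm import LLM

    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
    outputs = llm.generate("Hello, my name is")
    print(outputs[0].outputs[0].text)

When the AITER path is taken, the updated get_attn_backend_cls in vllm/platforms/rocm.py logs "Using Flash Attention backend on V1 engine."; otherwise it logs "Using Triton Attention backend on V1 engine."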