diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c47202099ac6..f12101e504eb 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -13,11 +13,19 @@
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
-from vllm.attention.ops.paged_attn import (PagedAttention,
-                                           PagedAttentionMetadata)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+USE_AITER_PAGED_ATTN = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
+
+if USE_AITER_PAGED_ATTN:
+    from vllm.attention.ops.paged_attn import PagedAttentionMetadata
+    from vllm.attention.ops.rocm_aiter_paged_attn import (
+        AiterPagedAttention as PagedAttention)
+else:
+    from vllm.attention.ops.paged_attn import (PagedAttention,
+                                               PagedAttentionMetadata)
+
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (
         ModelInputForGPUWithSamplingMetadata)
@@ -469,7 +477,15 @@ def __init__(
         if blocksparse_params is not None:
             raise ValueError(
                 "ROCmFlashAttention does not support blocksparse attention.")
-
+        self.aiter_kv_scales_initialized = False
+
+        if USE_AITER_PAGED_ATTN and kv_cache_dtype not in [
+                "int8", "fp8", "fp8_e4m3"
+        ]:
+            logger.warning("ROCm AITER paged attention does not "
+                           "support non-8-bit kv_cache data types. "
+                           "kv_cache_dtype: %s. Falling back to "
+                           "Triton PagedAttention.", kv_cache_dtype)
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             self.logits_soft_cap = 0.0
@@ -616,6 +632,31 @@ def forward(
         else:
             assert value is None
 
+        # Reshaping the kv scale tensors is required for the AITER paged
+        # attention kernel because it uses a different tensor layout when
+        # the element size is one byte (int8/fp8 dtypes).
+        # This reshaping is only required on the first forward call,
+        # and the kv cache must not be empty.
+        if (USE_AITER_PAGED_ATTN \
+                and kv_cache.dtype.itemsize == 1 \
+                and not self.aiter_kv_scales_initialized \
+                and kv_cache.shape != torch.Size([0])):
+
+            num_blocks = kv_cache.shape[1]
+            block_size = kv_cache.shape[2] // (self.num_kv_heads *
+                                               self.head_size)
+            k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            self.aiter_kv_scales_initialized = True
+            k_scale.fill_(layer._k_scale.item())
+            v_scale.fill_(layer._v_scale.item())
+            layer._k_scale = k_scale
+            layer._v_scale = v_scale
+
         # Only update KV cache for decoder self-attention
         # and encoder-decoder cross-attention
         if self.attn_type not in [
@@ -909,4 +950,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
+            and not USE_AITER_PAGED_ATTN)
diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py
new file mode 100644
index 000000000000..b8e5d41ae95a
--- /dev/null
+++ b/vllm/attention/ops/rocm_aiter_paged_attn.py
@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+import aiter as rocm_aiter
+import torch
+
+from vllm.attention.ops.paged_attn import PagedAttention
+
+
+class AiterPagedAttention(PagedAttention):
+
+    @staticmethod
+    def write_to_paged_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+    ) -> None:
+        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache, slot_mapping,
+                                                kv_cache_dtype, k_scale,
+                                                v_scale)
+        else:
+            if "fp8" in kv_cache_dtype:
+                key_cache = key_cache.view(torch.float8_e4m3fnuz)
+                value_cache = value_cache.view(torch.float8_e4m3fnuz)
+            else:
+                key_cache = key_cache.view(torch.int8)
+                value_cache = value_cache.view(torch.int8)
+            rocm_aiter.reshape_and_cache_with_pertoken_quant(
+                key, value, key_cache, value_cache, k_scale, v_scale,
+                slot_mapping.flatten(), True)
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        seq_lens: torch.Tensor,
+        max_seq_len: int,
+        kv_cache_dtype: str,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
+    ) -> torch.Tensor:
+        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
+            return PagedAttention.forward_decode(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                kv_cache_dtype=kv_cache_dtype,
+                num_kv_heads=num_kv_heads,
+                scale=scale,
+                alibi_slopes=alibi_slopes,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                tp_rank=tp_rank,
+                blocksparse_local_blocks=blocksparse_local_blocks,
+                blocksparse_vert_stride=blocksparse_vert_stride,
+                blocksparse_block_size=blocksparse_block_size,
+                blocksparse_head_sliding_step=blocksparse_head_sliding_step)
+
+        if "fp8" in kv_cache_dtype:
+            key_cache = key_cache.view(torch.float8_e4m3fnuz)
+            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+
+        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
+            # use blocksparse paged attention
+            block_size = value_cache.size(-1)
+            assert (blocksparse_block_size > 0 and
+                    blocksparse_block_size % block_size == 0), \
+                (f"{blocksparse_block_size=} needs to be a multiple of "
+                 f"{block_size=} used in block_tables.")
+
+        output = torch.empty_like(query)
+        block_size = value_cache.shape[3]
+        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+
+        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
+                              seq_lens, max_num_blocks_per_seq, k_scale,
+                              v_scale, output)
+        return output
diff --git a/vllm/envs.py b/vllm/envs.py
index a36d20a4f8b5..7b5819248fc1 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -75,6 +75,8 @@
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = False
+    VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
@@ -523,6 +525,19 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
 
+    # Use AITER ops unless specifically disabled.
+    # Acts as a parent switch that enables the other AITER operations.
+    "VLLM_ROCM_USE_AITER":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
+             ("true", "1")),
+
+    # Use AITER paged attention if AITER ops are enabled.
+    # This is disabled by default.
+    "VLLM_ROCM_USE_AITER_PAGED_ATTN":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
+             ("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN",
+                                         "False").lower() in ("true", "1")),
+
     # Pad the fp8 weights to 256 bytes for ROCm
     "VLLM_ROCM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
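
For reviewers who want to exercise the new path end to end, here is a minimal usage sketch. It is not part of the diff: it assumes a ROCm build of vLLM with the `aiter` package installed, and the model name is only a placeholder. Both environment variables must be set, and an 8-bit `kv_cache_dtype` is required; otherwise the backend logs the warning above and falls back to the Triton paged attention.

```python
# Usage sketch (assumptions: ROCm build of vLLM, `aiter` installed, a
# GPU/model combination that supports an fp8 KV cache; model is a placeholder).
import os

# Parent switch plus the paged-attention opt-in; both must be truthy,
# and they must be set before vLLM reads its environment variables.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "1"

from vllm import LLM, SamplingParams

# An 8-bit KV cache dtype ("fp8" here) is needed to hit the AITER kernels;
# with 16-bit KV caches the backend uses the existing paged attention.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```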