50 changes: 46 additions & 4 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -13,11 +13,19 @@
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform

USE_AITER_PAGED_ATTN = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN

if USE_AITER_PAGED_ATTN:
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.attention.ops.rocm_aiter_paged_attn import (
AiterPagedAttention as PagedAttention)
else:
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

@@ -469,7 +477,15 @@ def __init__(
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")

self.aiter_kv_scales_initialized = False

if USE_AITER_PAGED_ATTN and kv_cache_dtype not in [
"int8", "fp8", "fp8_e4m3"
]:
logger.warning("ROCM AITER paged attention does not "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really fallback here? Only see the warning here.

"support non-8-bit kv_cache data types. "
"kv_cache_dtype: {kv_cache_dtype}. "
"Falling back to Triton PagedAttention")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
self.logits_soft_cap = 0.0
@@ -616,6 +632,31 @@ def forward(
else:
assert value is None

        # The AITER paged attention kernel works with per-token k/v scales
        # (one float32 scale per kv head per cached token slot) instead of
        # scalar scales when the kv cache element size is one byte
        # (int8/fp8 dtypes). The scale tensors only need to be initialized
        # on the first forward call, and only when the kv cache is non-empty.
        if (USE_AITER_PAGED_ATTN
                and kv_cache.dtype.itemsize == 1
                and not self.aiter_kv_scales_initialized
                and kv_cache.shape != torch.Size([0])):

num_blocks = kv_cache.shape[1]
block_size = kv_cache.shape[2] // (self.num_kv_heads *
self.head_size)
k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
self.aiter_kv_scales_initialized = True
k_scale.fill_(layer._k_scale.item())
v_scale.fill_(layer._v_scale.item())
layer._k_scale = k_scale
layer._v_scale = v_scale

# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if self.attn_type not in [
@@ -909,4 +950,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and not USE_AITER_PAGED_ATTN)
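A note on the scale-initialization step in forward() above: the following is a minimal, CPU-only sketch (all sizes hypothetical, not taken from the PR) that mirrors the shape arithmetic, producing one float32 scale per kv head per cached token slot from a scalar scale.

# Sketch only: hypothetical sizes, mirrors the arithmetic in the diff above.
import torch

num_kv_heads, head_size = 8, 128          # hypothetical model config
num_blocks, block_size = 32, 16           # hypothetical cache geometry

# kv_cache layout assumed here: (2, num_blocks, block_size * num_kv_heads * head_size)
kv_cache = torch.empty(
    (2, num_blocks, block_size * num_kv_heads * head_size), dtype=torch.int8)

derived_num_blocks = kv_cache.shape[1]
derived_block_size = kv_cache.shape[2] // (num_kv_heads * head_size)

# One scale per kv head per cached token slot, initialized from a scalar.
k_scale = torch.full((num_kv_heads, derived_num_blocks * derived_block_size),
                     0.02, dtype=torch.float32)
v_scale = torch.full_like(k_scale, 0.02)
print(k_scale.shape)  # torch.Size([8, 512])

In the diff itself, these tensors are filled from layer._k_scale / layer._v_scale and rebound on the layer so that the AITER per-token quantization path can use them.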
98 changes: 98 additions & 0 deletions vllm/attention/ops/rocm_aiter_paged_attn.py
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import aiter as rocm_aiter
import torch

from vllm.attention.ops.paged_attn import PagedAttention


class AiterPagedAttention(PagedAttention):

@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, k_scale,
v_scale)
else:
if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz)
value_cache = value_cache.view(torch.float8_e4m3fnuz)
else:
key_cache = key_cache.view(torch.int8)
value_cache = value_cache.view(torch.int8)
rocm_aiter.reshape_and_cache_with_pertoken_quant(
key, value, key_cache, value_cache, k_scale, v_scale,
slot_mapping.flatten(), True)

@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
return PagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
kv_cache_dtype=kv_cache_dtype,
num_kv_heads=num_kv_heads,
scale=scale,
alibi_slopes=alibi_slopes,
k_scale=k_scale,
v_scale=v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step)

if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz)
value_cache = value_cache.view(torch.float8_e4m3fnuz)

if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")

output = torch.empty_like(query)
block_size = value_cache.shape[3]
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size

rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
seq_lens, max_num_blocks_per_seq, k_scale,
v_scale, output)
return output
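The dispatch rule shared by both overrides above is small enough to isolate; below is a standalone sketch (helper names are hypothetical, not part of the PR) that runs without aiter or a ROCm device.

# Sketch of the kv-cache-dtype dispatch used by AiterPagedAttention above.
import torch

_AITER_KV_CACHE_DTYPES = ("int8", "fp8", "fp8_e4m3")

def uses_aiter_path(kv_cache_dtype: str) -> bool:
    # 8-bit caches go to the AITER kernels; everything else falls back to
    # the stock PagedAttention implementation.
    return kv_cache_dtype in _AITER_KV_CACHE_DTYPES

def view_cache_for_aiter(cache: torch.Tensor, kv_cache_dtype: str) -> torch.Tensor:
    # fp8 caches are reinterpreted as float8_e4m3fnuz (the ROCm fp8 format),
    # int8 caches as int8, mirroring write_to_paged_cache/forward_decode.
    if "fp8" in kv_cache_dtype:
        return cache.view(torch.float8_e4m3fnuz)
    return cache.view(torch.int8)

assert uses_aiter_path("fp8") and uses_aiter_path("int8")
assert not uses_aiter_path("auto")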
15 changes: 15 additions & 0 deletions vllm/envs.py
@@ -75,6 +75,8 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = False
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
@@ -523,6 +525,19 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),

# Use aiter ops if explicitly enabled (disabled by default).
# Acts as a parent switch that gates the other aiter operations.
"VLLM_ROCM_USE_AITER":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),

# use aiter paged attention if aiter ops are enabled.
# this is disabled by default.
"VLLM_ROCM_USE_AITER_PAGED_ATTN":
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1") and os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN",
"False").lower() in ("true", "1")),

# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),