Merged
Changes from all commits
48 commits
f782c66
add AITER MLA implementation in attention backend
vllmellm Mar 28, 2025
42d5c62
remove unused arguments in aiter mla decode fwd kernel
vllmellm Mar 28, 2025
565a3fd
add unittest for AITER MLA backend in attention selector
vllmellm Mar 29, 2025
645f400
add unittest for MLA attention backend selector
vllmellm Apr 1, 2025
22c8726
code cleaning
vllmellm Apr 1, 2025
5dc1348
update AITER version
vllmellm Apr 1, 2025
12f8023
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 1, 2025
da8c69f
add ck flash attn in prefill mla computation
vllmellm Apr 2, 2025
1ea5718
further code cleaning
vllmellm Apr 2, 2025
681d777
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 2, 2025
9ada055
fix mypy typing errors
vllmellm Apr 3, 2025
1ceb3b9
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 3, 2025
20a3f07
fix mypy error on Iterable typing error
vllmellm Apr 3, 2025
194a42a
remove padding for v tensor in AITER MLA which improves performance
vllmellm Apr 15, 2025
a9a02d5
upgrade aiter package version
vllmellm Apr 15, 2025
02a4fb3
only support AITER FA in AITER MLA backend to avoid latency caused by…
vllmellm Apr 15, 2025
95213e2
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 15, 2025
6e48433
add missing data types of arguments in aiter_mla_decode_fwd
vllmellm Apr 16, 2025
0265f20
support AITER MLA backend on V1
vllmellm Apr 16, 2025
693c870
uncomment the required packages in common.txt
vllmellm Apr 16, 2025
a5a1a54
bugfix in building decode metadata for AITER MLA decode forward pass
vllmellm Apr 23, 2025
38c67c7
optimize the AITER decode metadata build
vllmellm Apr 24, 2025
74c9cb3
Merge remote-tracking branch 'origin/main' into aiter-mla-v1
vllmellm Apr 24, 2025
643d07f
bugfix caused by merging with main
vllmellm Apr 24, 2025
6171e50
Handle v1 AITER MLA backend in rocm platform
vllmellm Apr 24, 2025
905cec9
update AITER MLA decode metadata build
vllmellm May 1, 2025
20e769e
update AITER commit
vllmellm May 1, 2025
455bbf2
Merge remote-tracking branch 'origin/main' into aiter-mla-v1
vllmellm May 1, 2025
90daf6e
update proper logging info in selected backend as well as updating at…
vllmellm May 1, 2025
f68e926
fix wrong sync merge to main
vllmellm May 1, 2025
821f475
fix pre-commit
vllmellm May 1, 2025
1a6ba99
Merge remote-tracking branch 'origin/main' into aiter-mla-v1
vllmellm May 1, 2025
7cc28d2
add the missing line in common.py
vllmellm May 1, 2025
29fc060
fix wrong logger info message
vllmellm May 2, 2025
7f1ed77
clean code and fix AITER block scaled kernel fake impl in v1 engine
vllmellm May 5, 2025
11a8985
use env variable to adjsut timeout for model execution
vllmellm May 5, 2025
825f387
remove unnecessary backend check
vllmellm May 5, 2025
cbeb0df
make model execution timeout in envs variable rocm specific variable …
vllmellm May 5, 2025
2218bbc
fix unit-test
vllmellm May 6, 2025
cb98504
Merge remote-tracking branch 'origin/main' into aiter-mla-v1
vllmellm May 6, 2025
44d813f
bugfix to update AITER MLA V1 decode forward after sync with main
vllmellm May 6, 2025
423c0be
Update vllm/platforms/rocm.py
vllmellm May 6, 2025
58d79bd
address PR comments
vllmellm May 6, 2025
95644ea
update assertion message
vllmellm May 7, 2025
f41d616
remove env variable for model execution timeout
vllmellm May 8, 2025
56d2254
Merge remote-tracking branch 'origin/main' into aiter-mla-v1
vllmellm May 8, 2025
3ee787e
remove unnecessary warning
vllmellm May 8, 2025
f688418
keep model execution timeout as original value in main branch
vllmellm May 8, 2025
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="7e1ed08"
ARG AITER_BRANCH="5a77249"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
5 changes: 4 additions & 1 deletion tests/kernels/attention/test_attention_selector.py
@@ -102,7 +102,10 @@ def test_env(
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == name
if use_v1 and name != "TRITON_MLA":
assert backend.get_name() == f"{name}_VLLM_V1"
else:
assert backend.get_name() == name
else:
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
6 changes: 4 additions & 2 deletions tests/kernels/attention/test_rocm_attention_selector.py
@@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"
assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")

# If attention backend is None
# If use_mla is true
@@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
assert backend.get_name() == "ROCM_AITER_MLA"
assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
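
For context, here is a minimal usage sketch of how this backend gets requested, mirroring the selector tests above. The get_attn_backend arguments and import path are copied from the test call in this revision of vLLM, so treat them as assumptions rather than a stable API.

# Hedged sketch: request the AITER MLA backend the same way the tests above do.
import os

import torch

os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_MLA"  # explicit selection
# ...or leave the backend unset and enable AITER on ROCm instead:
# os.environ["VLLM_ROCM_USE_AITER"] = "1"
# os.environ["VLLM_ROCM_USE_AITER_MLA"] = "1"

from vllm.attention.selector import get_attn_backend

backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True)
assert backend.get_name() in ("ROCM_AITER_MLA", "ROCM_AITER_MLA_VLLM_V1")

On V1 the selector reports the name with the _VLLM_V1 suffix, which is why the test accepts either spelling.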
46 changes: 46 additions & 0 deletions vllm/attention/ops/rocm_aiter_mla.py
@@ -4,6 +4,9 @@

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
max_block_per_batch: int,
@@ -30,6 +33,28 @@ def aiter_mla_decode_fwd(
kv_last_page_lens: Optional[torch.Tensor] = None,
logit_cap: float = 0.0,
):

torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
kv_buffer.view(
-1, 1, 1, q.shape[-1]),
o,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap)


def mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd

mla_decode_fwd(q,
@@ -40,3 +65,24 @@ def aiter_mla_decode_fwd(
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap)


def mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass


if current_platform.is_rocm():
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
op_func=mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=mla_decode_fwd_fake,
Collaborator:
Curious: is the fake_impl necessary here?

Contributor Author (vllmellm):
@houseroad Yes, without it there is an error from torch dynamo while building CUDA graphs.
tags=[torch.Tag.needs_fixed_stride_order])
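
As background for the exchange above: torch dynamo traces custom ops with fake (meta) tensors, so an op registered this way needs a fake implementation the tracer can call, even if that body does nothing. Below is a minimal, self-contained sketch using the plain torch.library API (assumes PyTorch 2.4+), not vLLM's direct_register_custom_op helper.

# Minimal sketch (plain torch.library API, not vLLM's helper) of why a fake
# impl is registered: dynamo traces the op with fake/meta tensors, so it needs
# a body that performs no real computation.
import torch

@torch.library.custom_op("demo::scale_into", mutates_args=("out",))
def scale_into(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    out.copy_(x * scale)  # real eager kernel

@scale_into.register_fake
def _(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # The op mutates `out` in place and returns nothing, so the fake impl
    # has nothing to allocate; it only tells the tracer the call is valid.
    pass

def run(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    scale_into(x, out, 2.0)
    return out

compiled = torch.compile(run, backend="eager", fullgraph=True)
print(compiled(torch.ones(4), torch.empty(4)))  # tensor([2., 2., 2., 2.])

Because rocm_aiter_mla_decode_fwd mutates o in place and returns None, its fake implementation in the PR can likewise be a no-op.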
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -1319,6 +1319,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"ROCM_AITER_MLA",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
@@ -145,7 +145,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
block_shape: List[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:

return torch.empty_like(a1, dtype=torch.bf16)
return torch.empty_like(a1, dtype=hidden_states_dtype)


def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
3 changes: 2 additions & 1 deletion vllm/platforms/interface.py
@@ -39,7 +39,8 @@ class _Backend(enum.Enum):
TRITON_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_MLA = enum.auto() # Supported by V1
ROCM_AITER_MLA_VLLM_V1 = enum.auto()
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
11 changes: 8 additions & 3 deletions vllm/platforms/rocm.py
@@ -167,10 +167,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
elif selected_backend == _Backend.ROCM_AITER_MLA:
elif selected_backend == _Backend.ROCM_AITER_MLA \
or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
if block_size == 1:
logger.info("Using AITER MLA backend.")
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
if use_v1:
logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
logger.info("Using AITER MLA backend")
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
11 changes: 6 additions & 5 deletions vllm/v1/attention/backends/mla/common.py
@@ -496,11 +496,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)

# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
if self.aot_schedule:
Contributor:
Could you explain this a bit? Why was this change necessary?

Contributor Author (vllmellm, May 7, 2025):
@SageMoore self.page_size is only defined in __init__ under the self.aot_schedule condition, and on ROCm that condition is not true, so accessing it raises an error that self.page_size is not defined:

self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
# Dont try to access the runner on AMD
if self.aot_schedule:
    self.page_size = self.runner.block_size

You may want to ask the author about this, as these line changes were added in this PR.

In any case, if self.page_size were defined without the self.aot_schedule condition, it would not have any effect on ROCm, at least for AITER MLA, which is currently the only MLA backend on V1.

# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)

assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
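
The alignment discussed in this hunk is plain floor-to-multiple arithmetic; a tiny sketch follows, with round_down defined locally for illustration (assumed to match the helper used in the diff).

# Minimal sketch of the page-size alignment above: chunk boundaries are floored
# to a multiple of page_size so `gather_cache` only sees aligned
# `context_chunk_starts`.
def round_down(x: int, multiple: int) -> int:
    return (x // multiple) * multiple

page_size = 16
max_context_chunk = 1000
print(round_down(max_context_chunk, page_size))  # 992, i.e. 62 full pages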
196 changes: 196 additions & 0 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any, Optional

import torch

import vllm.envs as envs
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)

# yapf: enable


def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_ROCM_USE_AITER_MLA


class AiterMLABackend(MLACommonBackend):

@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA_VLLM_V1"

@staticmethod
def get_impl_cls() -> type["AiterMLAImpl"]:
return AiterMLAImpl

@staticmethod
def get_metadata_cls() -> type["AiterMLAMetadata"]:
return AiterMLAMetadata

@staticmethod
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder


@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None


class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
pass


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):

def __init__(self, runner):
super().__init__(runner)
max_model_len = self.runner.model_config.max_model_len
assert max_model_len == 32768,\
"AITER MLA requires max_model_len=32768"
assert self.runner.block_size == 1, "AITER MLA" \
"only supports block size 1."

def _get_paged_kv_tensors(
self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
page_size = self.runner.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size

mask = (torch.arange(block_table.size(1),
dtype=block_table.dtype,
device=block_table.device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]

paged_kv_indptr = torch.cat([
torch.zeros(1,
dtype=block_table_bounds.dtype,
device=block_table_bounds.device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])

paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
return (
paged_kv_indices,
paged_kv_indptr,
paged_kv_last_page_len,
)

def _build_decode(self, input_positions: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:

(
paged_kv_indices,
paged_kv_indptr,
paged_last_page_len,
) = self._get_paged_kv_tensors(block_table, seq_lens)

attn_metadata = AiterMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_last_page_len)

return attn_metadata
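
To make the decode metadata concrete, here is a small self-contained walkthrough (toy block IDs and sequence lengths, not from the PR) of the tensors _get_paged_kv_tensors produces when block size is 1, the only size this builder accepts.

# Toy walkthrough of the paged-KV tensors built above, assuming block_size == 1.
import torch

page_size = 1
seq_lens = torch.tensor([3, 1, 2], dtype=torch.int32)
# Padded block table: with block_size == 1, each token owns one physical block.
block_table = torch.tensor([[10, 11, 12, 0],
                            [20,  0,  0, 0],
                            [30, 31,  0, 0]], dtype=torch.int32)

block_table_bounds = (seq_lens + page_size - 1) // page_size  # tensor([3, 1, 2])
mask = (torch.arange(block_table.size(1)).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]            # tensor([10, 11, 12, 20, 30, 31])
paged_kv_indptr = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    block_table_bounds.cumsum(dim=0, dtype=torch.int32),
])                                              # tensor([0, 3, 4, 6])
last = seq_lens % page_size
paged_kv_last_page_len = torch.where(last == 0, page_size, last)  # tensor([1, 1, 1])
print(paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len)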


class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)

unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")

from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func

def _flash_attn_varlen_diff_headdims(self,
Contributor:
Where is this used?

Contributor Author (vllmellm):
@SageMoore you may want to check common.py; the method _flash_attn_varlen_diff_headdims is defined there and overridden in this class.

Contributor:
I see now. I must have mistyped the string when I searched for it :).
q,
k,
v,
return_softmax_lse=False,
softmax_scale=None,
**kwargs):
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)

return output

def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

B = q_nope.shape[0]

q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)

kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)

return self._v_up_proj(o)
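
Finally, a hedged shape sketch for _forward_decode: the concrete dimension values below are made up, and only the concatenation, the cache unsqueeze(2), and the output allocation mirror the code above.

# Hedged shape sketch; dimension values are illustrative only.
import torch

B, H = 4, 16                      # decode batch size, number of attention heads
kv_lora_rank, rope_dim = 512, 64
num_blocks, block_size = 128, 1   # AITER MLA requires block size 1

q_nope = torch.randn(B, H, kv_lora_rank)
q_pe = torch.randn(B, H, rope_dim)
q = torch.cat([q_nope, q_pe], dim=-1)            # [B, H, kv_lora_rank + rope_dim]

kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size, kv_lora_rank + rope_dim)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)     # [num_blocks, block_size, 1, head_dim]

o = torch.zeros(B, H, kv_lora_rank, dtype=q.dtype, device=q.device)
print(q.shape, kv_buffer.shape, o.shape)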