diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base
index 12009b8aa046..222b9c158e5e 100644
--- a/docker/Dockerfile.rocm_base
+++ b/docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="7e1ed08"
+ARG AITER_BRANCH="5a77249"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index b0414244c215..436cb430817e 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -102,7 +102,10 @@ def test_env(
                                        block_size,
                                        False,
                                        use_mla=use_mla)
-            assert backend.get_name() == name
+            if use_v1 and name != "TRITON_MLA":
+                assert backend.get_name() == f"{name}_VLLM_V1"
+            else:
+                assert backend.get_name() == name
         else:
             with pytest.raises(ValueError) as exc_info:
                 get_attn_backend(16,
diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py
index 4cf7bcb01d4d..6ffe27abf709 100644
--- a/tests/kernels/attention/test_rocm_attention_selector.py
+++ b/tests/kernels/attention/test_rocm_attention_selector.py
@@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
                                    False, True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
+        assert (backend.get_name() == "ROCM_AITER_MLA"
+                or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
 
         # If attention backend is None
         # If use_mla is true
@@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv("VLLM_ROCM_USE_AITER", "1")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
                                    False, True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
+        assert (backend.get_name() == "ROCM_AITER_MLA"
+                or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 1c90f8c19b09..3348d18804aa 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -4,6 +4,9 @@
 
 import torch
 
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
 
 def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                            max_block_per_batch: int,
@@ -30,6 +33,28 @@ def aiter_mla_decode_fwd(
     kv_last_page_lens: Optional[torch.Tensor] = None,
     logit_cap: float = 0.0,
 ):
+
+    torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
+                                             kv_buffer.view(
+                                                 -1, 1, 1, q.shape[-1]),
+                                             o,
+                                             kv_indptr,
+                                             kv_indices,
+                                             kv_last_page_lens,
+                                             sm_scale=sm_scale,
+                                             logit_cap=logit_cap)
+
+
+def mla_decode_fwd_impl(
+    q: torch.Tensor,
+    kv_buffer: torch.Tensor,
+    o: torch.Tensor,
+    kv_indptr: Optional[torch.Tensor] = None,
+    kv_indices: Optional[torch.Tensor] = None,
+    kv_last_page_lens: Optional[torch.Tensor] = None,
+    sm_scale: float = 1.0,
+    logit_cap: float = 0.0,
+) -> None:
     from aiter.mla import mla_decode_fwd
 
     mla_decode_fwd(q,
@@ -40,3 +65,24 @@ def aiter_mla_decode_fwd(
                    kv_last_page_lens,
                    sm_scale=sm_scale,
                    logit_cap=logit_cap)
+
+
+def mla_decode_fwd_fake(
+    q: torch.Tensor,
+    kv_buffer: torch.Tensor,
+    o: torch.Tensor,
+    kv_indptr: Optional[torch.Tensor] = None,
+    kv_indices: Optional[torch.Tensor] = None,
+    kv_last_page_lens: Optional[torch.Tensor] = None,
+    sm_scale: float = 1.0,
+    logit_cap: float = 0.0,
+) -> None:
+    pass
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
+                              op_func=mla_decode_fwd_impl,
+                              mutates_args=["o"],
+                              fake_impl=mla_decode_fwd_fake,
+                              tags=[torch.Tag.needs_fixed_stride_order])
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 27af74e2e349..be4be6ed5f78 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1319,6 +1319,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "FLASHMLA",
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
+            "ROCM_AITER_MLA",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index acaa93f5a23e..7d7bce9ec6ab 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -145,7 +145,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
         block_shape: List[int],
         smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-    return torch.empty_like(a1, dtype=torch.bf16)
+    return torch.empty_like(a1, dtype=hidden_states_dtype)
 
 
 def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 531b13da0fa1..e9c7f0cb5835 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -39,7 +39,8 @@ class _Backend(enum.Enum):
     TRITON_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
-    ROCM_AITER_MLA = enum.auto()
+    ROCM_AITER_MLA = enum.auto()  # Supported by V1
+    ROCM_AITER_MLA_VLLM_V1 = enum.auto()
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 8a4920303423..a0403e5325cb 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -167,10 +167,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"
                     f"does not support block size {block_size}.")
-        elif selected_backend == _Backend.ROCM_AITER_MLA:
+        elif selected_backend == _Backend.ROCM_AITER_MLA \
+            or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
             if block_size == 1:
-                logger.info("Using AITER MLA backend.")
-                return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
+                if use_v1:
+                    logger.info("Using AITER MLA backend on V1 engine.")
+                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
+                else:
+                    logger.info("Using AITER MLA backend")
+                    return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
             else:
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0d18a5639c2a..0040abeb183a 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -496,11 +496,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 max_context_chunk = (self.chunked_prefill_workspace_size //
                                      num_prefills_with_context_cpu)
 
-                # align max_context_chunk to page_size by rounding down,
-                # currently the `gather_cache` kernel cannot handle
-                # `context_chunk_starts` that are not aligned to page_size
-                max_context_chunk = round_down(max_context_chunk,
-                                               self.page_size)
+                if self.aot_schedule:
+                    # align max_context_chunk to page_size by rounding down,
+                    # currently the `gather_cache` kernel cannot handle
+                    # `context_chunk_starts` that are not aligned to page_size
+                    max_context_chunk = round_down(max_context_chunk,
+                                                   self.page_size)
 
                 assert max_context_chunk > 0
                 num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
new file mode 100644
index 000000000000..909e737dcea8
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -0,0 +1,196 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch
+
+import vllm.envs as envs
+from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
+# yapf conflicts with isort for this docstring
+# yapf: disable
+from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                   MLACommonDecodeMetadata,
+                                                   MLACommonImpl,
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+
+# yapf: enable
+
+
+def is_aiter_mla_enabled() -> bool:
+    return envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_ROCM_USE_AITER_MLA
+
+
+class AiterMLABackend(MLACommonBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "ROCM_AITER_MLA_VLLM_V1"
+
+    @staticmethod
+    def get_impl_cls() -> type["AiterMLAImpl"]:
+        return AiterMLAImpl
+
+    @staticmethod
+    def get_metadata_cls() -> type["AiterMLAMetadata"]:
+        return AiterMLAMetadata
+
+    @staticmethod
+    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
+        return AiterMLAMetadataBuilder
+
+
+@dataclass
+class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
+    # The indptr of the paged kv cache, shape: [batch_size + 1]
+    paged_kv_indptr: Optional[torch.Tensor] = None
+    # The page indices of the paged kv cache
+    paged_kv_indices: Optional[torch.Tensor] = None
+    # The number of entries in the last page of each request in
+    # the paged kv cache, shape: [batch_size]
+    paged_kv_last_page_len: Optional[torch.Tensor] = None
+
+
+class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
+    pass
+
+
+class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
+
+    def __init__(self, runner):
+        super().__init__(runner)
+        max_model_len = self.runner.model_config.max_model_len
+        assert max_model_len == 32768,\
+            "AITER MLA requires max_model_len=32768"
+        assert self.runner.block_size == 1, "AITER MLA" \
+            "only supports block size 1."
+
+    def _get_paged_kv_tensors(
+            self, block_table: torch.Tensor,
+            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        page_size = self.runner.block_size
+        block_table_bounds = (seq_lens + page_size - 1) // page_size
+
+        mask = (torch.arange(block_table.size(1),
+                             dtype=block_table.dtype,
+                             device=block_table.device).unsqueeze(0)
+                < block_table_bounds.unsqueeze(1))
+        paged_kv_indices = block_table[mask]
+
+        paged_kv_indptr = torch.cat([
+            torch.zeros(1,
+                        dtype=block_table_bounds.dtype,
+                        device=block_table_bounds.device),
+            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
+        ])
+
+        paged_kv_last_page_len = seq_lens % page_size
+        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
+                                             page_size, paged_kv_last_page_len)
+        return (
+            paged_kv_indices,
+            paged_kv_indptr,
+            paged_kv_last_page_len,
+        )
+
+    def _build_decode(self, input_positions: torch.Tensor,
+                      block_table: torch.Tensor,
+                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
+
+        (
+            paged_kv_indices,
+            paged_kv_indptr,
+            paged_last_page_len,
+        ) = self._get_paged_kv_tensors(block_table, seq_lens)
+
+        attn_metadata = AiterMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            paged_kv_indptr=paged_kv_indptr,
+            paged_kv_indices=paged_kv_indices,
+            paged_kv_last_page_len=paged_last_page_len)
+
+        return attn_metadata
+
+
+class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[list[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            blocksparse_params: Optional[dict[str, Any]],
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         blocksparse_params, logits_soft_cap, attn_type,
+                         **mla_args)
+
+        unsupported_features = [
+            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
+        ]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "Aiter MLA does not support one of the following: "
+                "alibi_slopes, sliding_window, blocksparse_params, "
+                "logits_soft_cap")
+
+        from aiter import flash_attn_varlen_func
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+
+    def _flash_attn_varlen_diff_headdims(self,
+                                         q,
+                                         k,
+                                         v,
+                                         return_softmax_lse=False,
+                                         softmax_scale=None,
+                                         **kwargs):
+        output = self.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            softmax_scale=softmax_scale,
+            return_lse=return_softmax_lse,
+            **kwargs,
+        )
+
+        return output
+
+    def _forward_decode(
+        self,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: AiterMLAMetadata,
+    ) -> torch.Tensor:
+        assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
+
+        B = q_nope.shape[0]
+
+        q = torch.cat([q_nope, q_pe], dim=-1)
+        o = torch.zeros(B,
+                        self.num_heads,
+                        self.kv_lora_rank,
+                        dtype=q.dtype,
+                        device=q.device)
+
+        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
+
+        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.decode.paged_kv_indptr,
+                             attn_metadata.decode.paged_kv_indices,
+                             attn_metadata.decode.paged_kv_last_page_len)
+
+        return self._v_up_proj(o)
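A minimal usage sketch of the path this change adds, for reviewers who want to exercise it locally. The environment variables, the block-size-1 restriction, and the max_model_len == 32768 assertion come from this diff; the model name and the exact LLM arguments below are illustrative assumptions only, and a ROCm build of vLLM with AITER installed is assumed.

    # Illustrative sketch, not part of the patch.
    import os

    # Force the AITER MLA backend; on the V1 engine the selector in
    # vllm/platforms/rocm.py resolves this to
    # vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend.
    os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_MLA"
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_MLA"] = "1"

    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MLA model
        block_size=1,         # AiterMLAMetadataBuilder asserts block_size == 1
        max_model_len=32768,  # and currently asserts max_model_len == 32768
    )
    print(llm.generate("Hello, my name is")[0].outputs[0].text)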