From 225d86e96c72e53edc499f092d2547eb860ed700 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 30 Jun 2025 08:51:43 +0000 Subject: [PATCH 1/4] support full graph Signed-off-by: vllmellm --- .../attention/backends/mla/rocm_aiter_mla.py | 79 +++++++++++-------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 8ad4e542b45b..742a3be27a6b 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional import torch @@ -54,7 +54,7 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): # The number of entries in the last page of each request in # the paged kv cache, shape: [batch_size] paged_kv_last_page_len: Optional[torch.Tensor] = None - # The query indptr, shape : [num_decode + 1] + # # The query indptr, shape : [num_decode + 1] qo_indptr: Optional[torch.Tensor] = None @@ -63,6 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + full_cudagraph_supported: ClassVar[bool] = True # decode only + force_separate_routine: ClassVar[Optional[bool]] = True def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -70,62 +72,71 @@ def __init__(self, runner, kv_cache_spec: AttentionSpec, assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." - def _get_paged_kv_tensors( - self, block_table: torch.Tensor, - seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: + # Preparing persistent buffers + self.paged_kv_indptr = torch.zeros(self.runner.max_num_reqs + 1, + dtype=torch.int32, + device=self.runner.device) + self.paged_kv_indices = torch.zeros( + block_table.get_device_tensor().numel(), # max num pages possible + dtype=torch.int32, + device=self.runner.device) + self.paged_kv_last_page_len = torch.zeros(self.runner.max_num_reqs, + dtype=torch.int32, + device=self.runner.device) + + self.qo_indptr = torch.arange(0, self.runner.max_num_reqs + 1, + step=1, + dtype=torch.int32, + device=self.runner.device) + + def _build_decode(self, block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens + page_size - 1) // page_size device = self.runner.device - mask = (torch.arange(block_table.size(1), - dtype=block_table.dtype, + mask = (torch.arange(block_table_tensor.size(1), + dtype=block_table_tensor.dtype, device=device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table[mask] + paged_kv_indices = block_table_tensor[mask] + num_actual_pages = paged_kv_indices.size(0) + + self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, + non_blocking=True) + self.paged_kv_indices[num_actual_pages:].fill_(-1) + + num_reqs = self._num_decodes paged_kv_indptr = torch.cat([ torch.zeros(1, dtype=block_table_bounds.dtype, device=device), block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) + self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, + non_blocking=True) + self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) + paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, 
page_size, paged_kv_last_page_len) - qo_indptr = torch.arange(0, - self._num_decodes + 1, - step=1, - dtype=torch.int32, - device=device) - - return ( - paged_kv_indices, - paged_kv_indptr, - paged_kv_last_page_len, - qo_indptr, - ) - - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: - - ( - paged_kv_indices, - paged_kv_indptr, - paged_last_page_len, - qo_indptr, - ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) + self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len, + non_blocking=True) + self.paged_kv_last_page_len[num_reqs:].fill_(1) attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_last_page_len, - qo_indptr=qo_indptr) + paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs], + paged_kv_indices=self.paged_kv_indices[:num_actual_pages], + paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs], + qo_indptr=self.qo_indptr[:1 + num_reqs]) return attn_metadata class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + def __init__( self, From ca4cb5e489ac9134c0ae4deca84ca105d4054a57 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 30 Jun 2025 11:04:50 +0000 Subject: [PATCH 2/4] clean code Signed-off-by: vllmellm --- .../v1/attention/backends/mla/rocm_aiter_mla.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 742a3be27a6b..faa00425e192 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import Any, ClassVar, Optional import torch @@ -63,8 +63,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - full_cudagraph_supported: ClassVar[bool] = True # decode only - force_separate_routine: ClassVar[Optional[bool]] = True + full_cudagraph_supported: ClassVar[bool] = True # decode only def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -84,10 +83,11 @@ def __init__(self, runner, kv_cache_spec: AttentionSpec, dtype=torch.int32, device=self.runner.device) - self.qo_indptr = torch.arange(0, self.runner.max_num_reqs + 1, - step=1, - dtype=torch.int32, - device=self.runner.device) + self.qo_indptr = torch.arange(0, + self.runner.max_num_reqs + 1, + step=1, + dtype=torch.int32, + device=self.runner.device) def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: @@ -101,7 +101,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] num_actual_pages = paged_kv_indices.size(0) - + self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, non_blocking=True) self.paged_kv_indices[num_actual_pages:].fill_(-1) @@ -136,7 +136,6 @@ def _build_decode(self, block_table_tensor: torch.Tensor, class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): - def __init__( self, From bf8f04474090c89fa401b27ad298356790c3192c Mon Sep 17 00:00:00 2001 From: vllmellm Date: Mon, 30 Jun 2025 13:07:26 +0000 Subject: [PATCH 3/4] remove the 
repeated # Signed-off-by: vllmellm --- vllm/v1/attention/backends/mla/rocm_aiter_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index faa00425e192..eab19cb61c0f 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -54,7 +54,7 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): # The number of entries in the last page of each request in # the paged kv cache, shape: [batch_size] paged_kv_last_page_len: Optional[torch.Tensor] = None - # # The query indptr, shape : [num_decode + 1] + # The query indptr, shape : [num_decode + 1] qo_indptr: Optional[torch.Tensor] = None


From 351587739687d863240062716f30717ba0fd4205 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Wed, 2 Jul 2025 08:26:16 +0000 Subject: [PATCH 4/4] preserve metadata state of piecewise graph to avoid performance overhead caused by full graph metadata tensor buffers Signed-off-by: vllmellm --- .../attention/backends/mla/rocm_aiter_mla.py | 88 +++++++++++-------- 1 file changed, 53 insertions(+), 35 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index eab19cb61c0f..d5f9dfaea065 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -72,22 +72,25 @@ def __init__(self, runner, kv_cache_spec: AttentionSpec, "only supports block size 1." # Preparing persistent buffers
-        self.paged_kv_indptr = torch.zeros(self.runner.max_num_reqs + 1,
-                                           dtype=torch.int32,
-                                           device=self.runner.device)
-        self.paged_kv_indices = torch.zeros(
-            block_table.get_device_tensor().numel(),  # max num pages possible
-            dtype=torch.int32,
-            device=self.runner.device)
-        self.paged_kv_last_page_len = torch.zeros(self.runner.max_num_reqs,
-                                                  dtype=torch.int32,
-                                                  device=self.runner.device)
-
-        self.qo_indptr = torch.arange(0, self.runner.max_num_reqs + 1,
-                                      step=1,
-                                      dtype=torch.int32,
-                                      device=self.runner.device)
+        if self.runner.full_cuda_graph:
+            device = self.runner.device
+            max_num_reqs = self.runner.max_num_reqs
+            self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device=device)
+            self.paged_kv_indices = torch.zeros(
+                block_table.get_device_tensor().numel(
+                ),  # max num pages possible
+                dtype=torch.int32,
+                device=device)
+            self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                      dtype=torch.int32,
+                                                      device=device)
+
+            self.qo_indptr = torch.arange(0,
+                                          max_num_reqs + 1,
+                                          dtype=torch.int32,
+                                          device=device)

 def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: @@ -100,37 +103,52 @@ def _build_decode(self, block_table_tensor: torch.Tensor, device=device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask]
-        num_actual_pages = paged_kv_indices.size(0)
-        self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
-                                                       non_blocking=True)
-        self.paged_kv_indices[num_actual_pages:].fill_(-1)
-
-        num_reqs = self._num_decodes
+        paged_kv_last_page_len = seq_lens % page_size
+        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
+                                             page_size, paged_kv_last_page_len)

 paged_kv_indptr = torch.cat([ torch.zeros(1, dtype=block_table_bounds.dtype, device=device), block_table_bounds.cumsum(dim=0, dtype=torch.int32) ])
-        self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
-                                                  non_blocking=True)
-
self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) + if self.runner.full_cuda_graph: + num_reqs = self._num_decodes - paged_kv_last_page_len = seq_lens % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) - self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len, - non_blocking=True) - self.paged_kv_last_page_len[num_reqs:].fill_(1) + num_actual_pages = paged_kv_indices.size(0) + + self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, + non_blocking=True) + self.paged_kv_indices[num_actual_pages:].fill_(-1) + paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + + self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, + non_blocking=True) + self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs] + + self.paged_kv_last_page_len[:num_reqs].copy_( + paged_kv_last_page_len, non_blocking=True) + self.paged_kv_last_page_len[num_reqs:].fill_(1) + paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] + + qo_indptr = self.qo_indptr[:1 + num_reqs] + + else: + qo_indptr = torch.arange(0, + self._num_decodes + 1, + step=1, + dtype=torch.int32, + device=device) attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens, - paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs], - paged_kv_indices=self.paged_kv_indices[:num_actual_pages], - paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs], - qo_indptr=self.qo_indptr[:1 + num_reqs]) + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + qo_indptr=qo_indptr) return attn_metadata
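Taken together, the series applies a common pattern for CUDA-graph-friendly attention metadata: allocate worst-case-sized device buffers once in the builder's `__init__`, write each step's freshly computed values into a prefix of those buffers with `copy_`, pad the unused tail with benign values, and hand out slices of the persistent storage so a captured graph always replays against the same device pointers. Patch 4 then gates this behind `self.runner.full_cuda_graph`, so the piecewise-graph path keeps building plain tensors and does not pay the extra copy/fill traffic. The sketch below illustrates the pattern in isolation; it is a minimal, hedged example rather than the vLLM code itself, and the class name `PersistentDecodeBuffers` and its parameters (`max_num_reqs`, `max_num_pages`, `use_persistent`) are placeholders invented for illustration.

```python
# Minimal sketch of the persistent-buffer pattern used by this series
# (illustrative only; names, shapes, and the surrounding API are simplified
# assumptions, not the actual vLLM interfaces).
import torch


class PersistentDecodeBuffers:

    def __init__(self, max_num_reqs: int, max_num_pages: int,
                 device: torch.device, use_persistent: bool):
        self.use_persistent = use_persistent
        if use_persistent:
            # Worst-case sized buffers allocated once, so their device
            # pointers stay stable across CUDA graph capture and replay.
            self.indptr = torch.zeros(max_num_reqs + 1,
                                      dtype=torch.int32,
                                      device=device)
            self.indices = torch.zeros(max_num_pages,
                                       dtype=torch.int32,
                                       device=device)
            self.last_page_len = torch.zeros(max_num_reqs,
                                             dtype=torch.int32,
                                             device=device)

    def build(self, indptr: torch.Tensor, indices: torch.Tensor,
              last_page_len: torch.Tensor) -> tuple[torch.Tensor, ...]:
        if not self.use_persistent:
            # Piecewise-graph path: return the freshly built tensors and
            # skip the copy/fill overhead entirely.
            return indptr, indices, last_page_len

        num_reqs = last_page_len.numel()
        num_pages = indices.numel()

        # Write this step's values into a prefix of the persistent buffers
        # and pad the tail with values that stay harmless if a padded graph
        # replay reads past the live region.
        self.indptr[:num_reqs + 1].copy_(indptr, non_blocking=True)
        self.indptr[num_reqs + 1:].fill_(indptr[-1])

        self.indices[:num_pages].copy_(indices, non_blocking=True)
        self.indices[num_pages:].fill_(-1)

        self.last_page_len[:num_reqs].copy_(last_page_len, non_blocking=True)
        self.last_page_len[num_reqs:].fill_(1)

        # Return views into the persistent storage, never new allocations.
        return (self.indptr[:num_reqs + 1], self.indices[:num_pages],
                self.last_page_len[:num_reqs])
```

This mirrors the trade-off the final commit makes: the full-graph path accepts a small per-step `copy_`/`fill_` cost in exchange for stable buffer addresses, while the default piecewise path stays exactly as cheap as it was before the series.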