From c717d69d102309d23095bf326b34da494ab68d10 Mon Sep 17 00:00:00 2001 From: "ygan@amd.com" Date: Fri, 19 Sep 2025 07:38:48 +0000 Subject: [PATCH 1/3] refactor attention backend for perf boost Signed-off-by: ganyi --- vllm/config/scheduler.py | 4 + vllm/platforms/rocm.py | 4 + vllm/v1/attention/backends/rocm_aiter_fa.py | 722 ++++++++++++++++---- vllm/v1/attention/backends/utils.py | 258 ++++++- vllm/v1/core/sched/scheduler.py | 41 +- vllm/v1/worker/gpu_model_runner.py | 21 +- 6 files changed, 915 insertions(+), 135 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 93002012799..1f4efb55644 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -159,6 +159,10 @@ class SchedulerConfig: structured outputs, speculative decoding, and pipeline parallelism. """ + split_prefill_from_chunk: bool = False + """Whether to split the prefill request into pure prefill and chunked prefill in a single + batch.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ecc34cb5710..785f618f84c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -346,6 +346,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA: + # enable the request reorder if we are using AITER MHA for calculation + vllm_config.scheduler_config.split_prefill_from_chunk = True + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 5b56f049386..804901c91b2 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,27 +2,192 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import Optional, ClassVar import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + split_decodes_prefills_and_chunk) +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 +_CHUNK_PREFILL_TOKENS_PER_ITER_ROCM = 32 * 1024 + +KV_CACHE_LAYOUT_V0 = False + if current_platform.is_rocm(): import aiter - from vllm.triton_utils import tl, triton + # from vllm.triton_utils import tl, triton + import triton + import triton.language as tl from vllm.utils import direct_register_custom_op + from aiter.ops.triton.utils.device_info import get_num_sms + + def block_size(x, head_dim): + return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) + + def num_programs(head_dim): + return min(head_dim, get_num_sms()) + + @triton.jit + def cp_mha_gather_cache_kernel( + key_cache_ptr, # [num_blocks, num_heads, head_size / x, page_size, x] or [num_blocks, page_size, num_head, head_size] + value_cache_ptr, # [num_blocks, num_heads, head_size, page_size] or [num_blocks, 
page_size, num_head, head_size] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + block_table_ptr, # [num_batches, max_block_num] + cu_seqlens_kv_ptr, # [num_batches + 1] + token_to_batch_ptr, # [max_cum_tokens] note: max_cum_tokens should always larger or equal than max_tokens + seq_start_ptr, # [num_batches] + k_scale_ptr, + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + num_tokens, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_PRGMS: tl.constexpr + ): + bid = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + + for token_id in tl.range(bid, num_tokens, NUM_PRGMS): + key_ptr_offset = key_ptr + token_id * head_size * num_heads + value_ptr_offset = value_ptr + token_id * head_size * num_heads + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load(block_table_ptr + max_block_num * batch_idx + block_offset) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "v0": + # For kv cache layout as + # K: [num_blocks, num_heads, head_size / x, page_size, x] + # V: [num_blocks, num_heads, head_size, page_size] + key_cache_ptr_offset = key_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id * x + value_cache_ptr_offset = value_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id + # since the num_head and head_dim are not contiguous, we use two loop the iter over the data + for head in tl.range(0, num_heads): + src_head_offset = head * PAGE_SIZE * head_size + dst_head_offset = head * head_size + for i in tl.range(0, head_size, BLOCK_SIZE): + mask = (col_offsets + i) < head_size + k_offset = (col_offsets + i) // x * PAGE_SIZE * x + col_offsets % x + k_reg = tl.load(key_cache_ptr_offset + src_head_offset + k_offset, mask=mask) + v_offset = (col_offsets + i) * PAGE_SIZE + v_reg = tl.load(value_cache_ptr_offset + src_head_offset + v_offset, mask=mask) + if DEQUANT: + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + + k_reg = (k_reg.to(tl.float32) * v_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * k_scale).to(v_dtype) + + tl.store(key_ptr_offset + dst_head_offset + col_offsets, k_reg, mask=mask) + tl.store(value_ptr_offset + dst_head_offset + col_offsets, v_reg, mask=mask) + elif CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = key_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id * num_heads * head_size + value_cache_ptr_offset = value_cache_ptr + block_id * num_heads * head_size * PAGE_SIZE + slot_id * num_heads * head_size + for i in tl.range(0, head_size * num_heads, BLOCK_SIZE): + mask = (col_offsets + i) < head_size * num_heads + k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask) + v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask) + if DEQUANT: + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask) + tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask) + + + def 
cp_mha_gather_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + block_tables: torch.Tensor, + k_scales: float, + v_scales: float, + cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, + dequant: bool, + kv_cache_layout: str, + total_tokens: int + ): + assert kv_cache_layout in ["v0", "NHD", "HND"], "kv_cache_layout only support v0, NHD, HND" + head_dim = key.shape[2] + x = 0 + assert dequant is True, "Currently, we only support gather cache with dequant" + # For k cache layout: [num_blocks, num_heads, head_dim / x, page_size, x] + if kv_cache_layout == "v0": + x = key_cache.shape[4] + num_heads = key.shape[1] + page_size = key_cache.shape[3] + assert x * key_cache.shape[2] == head_dim, "We assume your kv cache layout is [num_blocks, num_heads, head_dim/x, page_size, x], but got otherwise" + # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + elif kv_cache_layout == "HND": + assert False + assert head_dim == key_cache.shape[3], "We assume your kv cache layout is [num_blocks, num_heads, page_size, head_dim], but got otherwise" + page_size = key_cache.shape[2] + num_heads = key_cache.shape[1] + elif kv_cache_layout == "NHD": + assert head_dim == key_cache.shape[3], "We assume your kv cache layout is [num_blocks, page_size, num_heads, head_dim], but got otherwise" + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + else: + raise RuntimeError + + NUM_PRGMS = num_programs(total_tokens) + BLOCK_SIZE = block_size(key_cache, head_dim) + grid = lambda meta: (NUM_PRGMS, ) + cp_mha_gather_cache_kernel[grid]( + key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + total_tokens, + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, + BLOCK_SIZE=BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS + ) + @triton.jit def _vllm_layout_trans_kernel( @@ -36,6 +201,7 @@ def _vllm_layout_trans_kernel( block_table_stride_0, k_scale, v_scale, + skip_query: tl.constexpr, output_dtype: tl.constexpr, E_DIM: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -43,13 +209,14 @@ def _vllm_layout_trans_kernel( batch_idx = tl.program_id(0) block_idx = tl.program_id(1) - batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + - tl.arange(0, 2)) - batch_query_start, batch_query_end = tl.split(batch_query_indexes) - query_len = batch_query_end - batch_query_start + if skip_query: + batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + + tl.arange(0, 2)) + batch_query_start, batch_query_end = tl.split(batch_query_indexes) + query_len = batch_query_end - batch_query_start - if query_len <= 1: - return + if query_len <= 1: + return batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2)) @@ -124,6 +291,9 @@ def vllm_layout_trans(b_query_lens_loc, output_dtype = tl.bfloat16 else: raise ValueError(f"Unsupported output dtype: {output_dtype}") + skip_query = False + if b_query_lens_loc is None: + skip_query = True _vllm_layout_trans_kernel[grid](k_cache, v_cache, @@ -136,6 +306,7 @@ def vllm_layout_trans(b_query_lens_loc, k_scale, v_scale, output_dtype=output_dtype, + skip_query=skip_query, E_DIM=H_KV * D, BLOCK_SIZE=BLOCK_SIZE) @@ -209,9 +380,43 @@ def flash_attn_varlen_func_fake( flash_attn_varlen_func_fake, dispatch_key=current_platform.dispatch_key) -logger = init_logger(__name__) +@dataclass +class AiterFlashAttentionDecodeMetadata: + 
max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + +@dataclass +class AiterFlashAttentionPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + +@dataclass +class AiterChunkContextMetadata: + workspace: torch.Tensor + cu_seq_lens_chunk: torch.Tensor + chunk_starts: torch.Tensor + token_to_batch: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + num_chunks: int + total_token_per_batch: list[int] + + +@dataclass +class AiterFlashAttentionChunkPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + chunk_context_metadata: AiterChunkContextMetadata + @dataclass class AiterFlashAttentionMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. @@ -230,7 +435,18 @@ class AiterFlashAttentionMetadata: seq_lens: torch.Tensor slot_mapping: torch.Tensor block_table: torch.Tensor - cu_seq_lens: Optional[torch.Tensor] + + # prefill and deocde split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + num_chunk_prefills: int + num_chunk_prefill_tokens: int + + decode_metadata: Optional[AiterFlashAttentionDecodeMetadata] + pure_prefill_metadata: Optional[AiterFlashAttentionPrefillMetadata] + chunk_prefill_metadata: Optional[AiterFlashAttentionChunkPrefillMetadata] # For cascade attention. use_cascade: bool @@ -242,6 +458,8 @@ class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata]): cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + reorder_batch_threshold: ClassVar[int] = 1 + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config @@ -263,6 +481,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 + self.chunk_prefill_workspace_size = _CHUNK_PREFILL_TOKENS_PER_ITER_ROCM * self.num_heads_kv * self.headdim + + self.chunk_prefill_workspace = torch.empty( + [2, _CHUNK_PREFILL_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim], + dtype=self.model_config.dtype, + device=device + ) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata): self.total_tokens = self.model_config.max_model_len \ @@ -276,41 +502,109 @@ def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - if max_query_len > 1: - # We pre-compute cumulative seq len needed for prefill attention - # here to avoid recomputing it for every layer - cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, - dtype=torch.int32, - device=seq_lens.device) - torch.cumsum(seq_lens, - dim=0, - dtype=cu_seq_lens.dtype, - out=cu_seq_lens[1:]) - num_actual_kv_tokens = int(cu_seq_lens[-1].item()) - else: - cu_seq_lens = None - num_actual_kv_tokens = 0 + + split_ret = \ + split_decodes_prefills_and_chunk(common_attn_metadata, + 
decode_threshold=self.reorder_batch_threshold) + + num_decodes, num_chunk_prefills, num_pure_prefills, num_decode_tokens, num_chunk_prefill_tokens, num_pure_prefill_tokens = split_ret + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + seq_lens = common_attn_metadata.seq_lens_cpu + + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + decode_metadata = None + if num_decodes > 0: + decode_metadata = AiterFlashAttentionDecodeMetadata( + max_query_len=query_lens_cpu[:num_decodes].max().item(), + min_query_len=query_lens_cpu[:num_decodes].min().item(), + max_seq_len=seq_lens[:num_decodes].max().item(), + query_start_loc=common_attn_metadata.query_start_loc[:num_decodes + 1] + ) + + pure_prefill_metadata = None + if num_pure_prefills > 0: + query_lens_for_pure_prefill = query_lens_cpu[num_decodes + num_chunk_prefills:] + query_start_loc_device = common_attn_metadata.query_start_loc[num_decodes + num_chunk_prefills:] + pure_prefill_metadata = AiterFlashAttentionPrefillMetadata( + max_query_len=query_lens_for_pure_prefill.max().item(), + min_query_len=query_lens_for_pure_prefill.min().item(), + max_seq_len=seq_lens[num_decodes + num_chunk_prefills:].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0] + ) + + chunk_prefill_metadata = None + if num_chunk_prefills > 0: + query_lens_for_chunk_prefill = query_lens_cpu[num_decodes:num_decodes + num_chunk_prefills] + seq_lens_for_chunk_prefill = common_attn_metadata.seq_lens_cpu[num_decodes: num_decodes + num_chunk_prefills] + computed_kv_lens = seq_lens_for_chunk_prefill - query_lens_for_chunk_prefill + + # allocate the equal amount of workspace for each chunk prefill request + max_context_chunk = (_CHUNK_PREFILL_TOKENS_PER_ITER_ROCM // num_chunk_prefills) + num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk) + + + chunk_starts = torch.arange(num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, num_chunk_prefills) * max_context_chunk + chunk_ends = torch.min(computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) # [num_chunks, num_chunk_prefills] + cu_seq_lens_cpu = torch.zeros([num_chunks, num_chunk_prefills + 1], dtype=torch.int32, pin_memory=True) + torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item() + + + range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :] # [num_chunks, num_chunk_prefills, max_cum_tokens] + idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None] # [num_chunks, num_chunk_prefills, max_cum_tokens] + idx_to_batch_tensor = idx_to_batch_tensor.sum(dim=1) # [num_chunks, max_cum_tokens] + token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1) + + chunk_context_metadata = AiterChunkContextMetadata( + workspace=self.chunk_prefill_workspace, + cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True), + chunk_starts=chunk_starts.to(self.device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True), + num_chunks=num_chunks, + total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist() + ) + + query_start_loc_device = common_attn_metadata.query_start_loc[num_decodes:num_decodes + num_chunk_prefills + 1] + seq_lens_device = common_attn_metadata.seq_lens[num_decodes:num_decodes + 
num_chunk_prefills] + cu_seq_lens = torch.zeros(num_chunk_prefills + 1, dtype=torch.int32, device=seq_lens_device.device) + torch.cumsum(seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) + chunk_prefill_metadata = AiterFlashAttentionChunkPrefillMetadata( + max_query_len=query_lens_for_chunk_prefill.max().item(), + min_query_len=query_lens_for_chunk_prefill.min().item(), + max_seq_len=seq_lens[num_decodes:num_decodes + num_chunk_prefills].max().item(), + query_start_loc=query_start_loc_device - query_start_loc_device[0], + chunk_context_metadata=chunk_context_metadata + ) + + num_actual_kv_tokens = torch.sum(seq_lens).item() use_cascade = common_prefix_len > 0 attn_metadata = AiterFlashAttentionMetadata( - num_actual_tokens=num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - cu_seq_lens=cu_seq_lens, + max_query_len=common_attn_metadata.max_query_len, + query_start_loc=common_attn_metadata.query_start_loc, + max_seq_len=common_attn_metadata.max_seq_len, + seq_lens=common_attn_metadata.seq_lens, + block_table=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_pure_prefills, + num_prefill_tokens=num_pure_prefill_tokens, + num_chunk_prefills=num_chunk_prefills, + num_chunk_prefill_tokens=num_chunk_prefill_tokens, + decode_metadata=decode_metadata, + pure_prefill_metadata=pure_prefill_metadata, + chunk_prefill_metadata=chunk_prefill_metadata, use_cascade=use_cascade, common_prefix_len=common_prefix_len, total_tokens=self.total_tokens, @@ -369,7 +663,10 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + if KV_CACHE_LAYOUT_V0: + return (2, num_blocks, num_kv_heads, block_size, head_size) + else: + return (2, num_blocks, block_size, num_kv_heads, head_size) class AiterFlashAttentionImpl(AttentionImpl): @@ -416,9 +713,113 @@ def __init__( "encoder/decoder cross-attention " "are not implemented for " "FlashAttentionImpl") - self.sinks = sinks - if self.sinks is not None: - raise NotImplementedError("Sinks are not supported for ROCM AITER") + + + def chunk_prefill_forward( + self, + attn_metadata: AiterFlashAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + block_table: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: float, + v_scale: float, + ): + out, lse = aiter.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_q, + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True + ) + chunk_context_metadata = attn_metadata.chunk_prefill_metadata.chunk_context_metadata + seq_lens = chunk_context_metadata.seq_lens + num_chunks = chunk_context_metadata.num_chunks + workspace = chunk_context_metadata.workspace + cu_seqlens_kv = 
chunk_context_metadata.cu_seq_lens_chunk + max_seqlens = chunk_context_metadata.max_seq_lens + chunk_starts = chunk_context_metadata.chunk_starts + token_to_batch = chunk_context_metadata.token_to_batch + total_token_per_batch = chunk_context_metadata.total_token_per_batch + key_fetched, value_fetched= workspace[0], workspace[1] + chunked_output = None + chunked_lse = None + for chunk_idx in range(num_chunks): + + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv[chunk_idx], + token_to_batch=token_to_batch[chunk_idx], + seq_starts=chunk_starts[chunk_idx], + dequant=True, + kv_cache_layout="v0" if KV_CACHE_LAYOUT_V0 else "NHD", + total_tokens=total_token_per_batch[chunk_idx], + ) + + suf_out, suf_lse = aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv[chunk_idx], + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlens[chunk_idx], + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=False, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True + ) + if chunked_output is None: + chunked_output = suf_out + chunked_lse = suf_lse + else: + tmp_output = torch.empty_like(out) + tmp_lse = torch.empty_like(lse) + merge_attn_states( + output=tmp_output, + output_lse=tmp_lse, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=suf_out, + suffix_lse=suf_lse + ) + chunked_output = tmp_output + chunked_lse = tmp_lse + + merge_attn_states( + output=output, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=out, + suffix_lse=lse, + ) + def forward( self, @@ -439,7 +840,10 @@ def forward( key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] kv_cache: shape = - [2, num_blocks, block_size, num_kv_heads, head_size] + [2, num_blocks, block_size * num_kv_heads * head_size] + more specifically: + k_cache = [num_blocks, num_kv_heads, head_dim / x, block_size, x] + v_cache = [num_blocks, num_kv_heads, block_size, head_dim] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -458,6 +862,7 @@ def forward( # Profiling run. return output + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -466,7 +871,6 @@ def forward( # Minimize the PyTorch ops in this method as much as possible. # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: @@ -477,86 +881,168 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. 
- torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - if self.kv_cache_dtype.startswith("fp8"): - if current_platform.is_fp8_fnuz(): - key_cache = key_cache.view(torch.float8_e4m3fnuz) - value_cache = value_cache.view(torch.float8_e4m3fnuz) + if KV_CACHE_LAYOUT_V0: + num_blocks = key_cache.shape[0] + num_heads = key_cache.shape[1] + block_size = key_cache.shape[2] + head_size = key.shape[2] + x = 16 // key_cache.dtype.itemsize + + key_cache = key_cache.view([num_blocks, num_heads, head_size // x, block_size, x]) + value_cache = value_cache.view([num_blocks, num_heads, head_size, block_size]) + torch.ops._C_cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) else: - key_cache = key_cache.view(torch.float8_e4m3fn) - value_cache = value_cache.view(torch.float8_e4m3fn) - - if not attn_metadata.use_cascade: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - - if max_seqlen_q > 1: - torch.ops.vllm.flash_attn_varlen_func( - query[:num_actual_tokens], + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, key_cache, value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + + # decode:chunk_prefill:pure_prefill + query = query[:num_actual_tokens] + key = key[:num_actual_tokens] + value = value[:num_actual_tokens] + + output_actual_tokens = output[:num_actual_tokens] + + block_table = attn_metadata.block_table + num_decodes = attn_metadata.num_decodes + num_pure_prefills = attn_metadata.num_prefills + num_chunk_prefills = attn_metadata.num_chunk_prefills + + num_pure_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + num_chunk_prefill_tokens = attn_metadata.num_chunk_prefill_tokens + if not attn_metadata.use_cascade: + + # calculate for pure prefills + if num_pure_prefills > 0: + + prefill_query = query[num_decode_tokens + num_chunk_prefill_tokens:] + prefill_key = key[num_decode_tokens + num_chunk_prefill_tokens:] + prefill_value = value[num_decode_tokens + num_chunk_prefill_tokens:] + + aiter.flash_attn_varlen_func( + q=prefill_query, + k=prefill_key, + v=prefill_value, + cu_seqlens_q=attn_metadata.pure_prefill_metadata.query_start_loc, + cu_seqlens_k=attn_metadata.pure_prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.pure_prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.pure_prefill_metadata.max_seq_len, + min_seqlen_q=attn_metadata.pure_prefill_metadata.min_query_len, + dropout_p=0.0, softmax_scale=self.scale, - alibi_slopes=self.alibi_slopes, + causal=True, window_size=self.sliding_window, - block_table=block_table, - cu_seqlens_k=attn_metadata.cu_seq_lens, + alibi_slopes=self.alibi_slopes, + out=output_actual_tokens[num_decode_tokens + num_chunk_prefill_tokens:], + ) + + # calculate for chunk prefills + if num_chunk_prefills > 0: + chunk_prefill_querys = 
query[num_decode_tokens:num_decode_tokens + num_chunk_prefill_tokens] + chunk_prefill_keys = key[num_decode_tokens:num_decode_tokens + num_chunk_prefill_tokens] + chunk_prefill_values = value[num_decode_tokens:num_decode_tokens + num_chunk_prefill_tokens] + chunk_prefill_outputs = output[num_decode_tokens:num_decode_tokens + num_chunk_prefill_tokens] + self.chunk_prefill_forward( + attn_metadata=attn_metadata, + query=chunk_prefill_querys, + key=chunk_prefill_keys, + value=chunk_prefill_values, + key_cache=key_cache, + value_cache=value_cache, + output=chunk_prefill_outputs, + cu_seqlens_q=attn_metadata.chunk_prefill_metadata.query_start_loc, + max_seqlen_q=attn_metadata.chunk_prefill_metadata.max_query_len, + max_seqlen_k=attn_metadata.chunk_prefill_metadata.max_seq_len, + min_seqlen_q=attn_metadata.chunk_prefill_metadata.min_query_len, + block_table=attn_metadata.block_table[num_decodes:num_decodes + num_chunk_prefills], + slot_mapping=attn_metadata.slot_mapping[num_decodes:num_decodes + num_chunk_prefills], k_scale=layer._k_scale, v_scale=layer._v_scale, - total_tokens=attn_metadata.num_actual_kv_tokens, ) - _, num_heads, head_size = query.shape - nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 - num_seqs = seqused_k.shape[0] - max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM - - 1) // _PARTITION_SIZE_ROCM - - workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) * - nbytes_per_qo_elem + 2 * - (num_seqs * num_heads * max_num_partitions) * 4, - dtype=torch.uint8, - device=output.device, - ) - - torch.ops.aiter.paged_attention_v1( - output[:num_actual_tokens], - workspace_buffer, - query[:num_actual_tokens], - key_cache, - value_cache, - self.scale, - block_table, - cu_seqlens_q, - seqused_k, - max_seqlen_k, - self.alibi_slopes, - self.kv_cache_dtype, - "NHD", - self.logits_soft_cap, - layer._k_scale, - layer._v_scale, - None, - _PARTITION_SIZE_ROCM, - ) - return output + # calculate for decodes + if num_decodes > 0: + if KV_CACHE_LAYOUT_V0: + # ============= spec decode ================= + # kv cache layout: [num_blocks, num_heads, head_dim / x, page_size, x] + from aiter.paged_attn import PagedAttention + # for spec decode impl + decode_output = PagedAttention.forward_decode( + query[:num_decode_tokens], + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_table[:num_decode_tokens], + seq_lens=attn_metadata.seq_lens[:num_decodes], + max_seq_len=attn_metadata.decode_metadata.max_seq_len, + kv_cache_dtype=self.kv_cache_dtype, + num_kv_heads=self.num_kv_heads, + scale=self.scale, + alibi_slopes=self.alibi_slopes, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + mtp=attn_metadata.decode_metadata.max_query_len + ) + output_actual_tokens[:num_decode_tokens] = decode_output + # ============= spec decode ================= + else: + _, num_heads, head_size = query.shape + nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 + max_num_partitions = (attn_metadata.decode_metadata.max_seq_len + _PARTITION_SIZE_ROCM - + 1) // _PARTITION_SIZE_ROCM + + workspace_buffer = torch.empty( + (num_decode_tokens * num_heads * max_num_partitions * head_size) * + nbytes_per_qo_elem + 2 * + (num_decode_tokens * num_heads * max_num_partitions) * 4, + dtype=torch.uint8, + device=output.device, + ) + + torch.ops.aiter.paged_attention_v1( + output_actual_tokens[:num_decode_tokens], + workspace_buffer, + query[:num_decode_tokens], + key_cache, + value_cache, + self.scale, + attn_metadata.block_table[:num_decodes], + 
attn_metadata.decode_metadata.query_start_loc, + attn_metadata.seq_lens[:num_decodes], + attn_metadata.decode_metadata.max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + "NHD", + self.logits_soft_cap, + layer._k_scale, + layer._v_scale, + None, + _PARTITION_SIZE_ROCM, + ) else: raise NotImplementedError( - "Cascade attention is not implemented for ROCM AITER") \ No newline at end of file + "Cascade attention is not implemented for ROCM AITER") + + return output + diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 39bdbe12563..235361c62c3 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,8 +4,10 @@ import enum import functools from abc import abstractmethod -from dataclasses import dataclass, make_dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar +from dataclasses import dataclass, fields, make_dataclass +from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol, + TypeVar) +from collections import deque import numpy as np import torch @@ -542,6 +544,69 @@ def make_local_attention_virtual_batches( ) +def make_kv_sharing_fast_prefill_common_attn_metadata( + common_attn_metadata: CommonAttentionMetadata, +) -> CommonAttentionMetadata: + if common_attn_metadata.max_query_len == 1: + # All requests are decode (assume 1 token for now) + # Skip computing fast prefill path + return common_attn_metadata + + assert common_attn_metadata.logits_indices_padded is not None + assert common_attn_metadata.num_logits_indices is not None + + logits_indices_padded = common_attn_metadata.logits_indices_padded + num_logits_indices = common_attn_metadata.num_logits_indices + # Get rid of CUDAGraph padding, if any + logits_indices = logits_indices_padded[:num_logits_indices] + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + query_start_loc_cpu=decode_query_start_loc.to("cpu", + non_blocking=True), + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + causal=True, + ) + return common_attn_metadata + + def 
subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], builder_cls: type[AttentionMetadataBuilder[M]] @@ -555,6 +620,67 @@ def subclass_attention_backend( {"get_builder_cls": lambda: builder_cls}) +def split_decodes_prefills_and_chunk( + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: CommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. + """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, 0, num_tokens, 0, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, 0, num_tokens, 0, 0 + + + first_prefill = is_prefill.int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] > decode_threshold), f"got query lens: {query_lens[first_prefill:]} and decode threshold {decode_threshold}" + assert torch.all(query_lens[:first_prefill] <= decode_threshold), f"got query lens: {query_lens[:first_prefill]} and decode threshold {decode_threshold}" + num_decodes = first_prefill + num_decode_tokens = query_start_loc[first_prefill].item() + + query_lens_prefill = query_lens[first_prefill:] + seq_lens_prefill = seq_lens[first_prefill:] + is_pure_prefill = seq_lens_prefill == query_lens_prefill + + if torch.all(is_pure_prefill): + num_pure_prefills = num_reqs - num_decodes + num_pure_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, 0, num_pure_prefills, num_decode_tokens, 0, num_pure_prefill_tokens) + + num_prefills = num_reqs - num_decodes + num_prefill_tokens = num_tokens - num_decode_tokens + first_chunk_prefill = is_pure_prefill.int().argmax(dim=-1).item() + + num_chunk_prefills = first_chunk_prefill + num_pure_prefills = num_prefills - first_chunk_prefill + + num_chunk_prefill_tokens = query_start_loc[num_chunk_prefills + num_decodes].item() - num_decode_tokens + num_pure_prefill_tokens = num_tokens - num_decode_tokens - num_chunk_prefill_tokens + return (num_decodes, num_chunk_prefills, num_pure_prefills, num_decode_tokens, num_chunk_prefill_tokens, num_pure_prefill_tokens) + + def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, @@ -597,10 +723,138 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) +def reorder_batch_to_split_decodes_prefills_and_chunks( + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + decode_threshold: int = 1, +) -> bool: + """ + Reorders the batch to split into prefill, chunk_prefill and decode requests; places all + requests in the order of [decodes:chunked_prefills:pure_prefills]. + + Returns: + True if the batch was modified, False otherwise. 
+ """ + + # We assume most of the request is already in the order of what we desired since this function + # should only be opened after the `SchedulerConfig.split_prefill_from_chunk` is True. So we only + # need to spot all mismatched request and swap their positions for efficiency. + + decodes = [] + prefills = [] + chunk_prefills = [] + + def print_order_of_batch(): + new_decode = [] + new_chunk_prefill = [] + new_prefills = [] + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= decode_threshold: + new_decode.append(i) + elif input_batch.num_computed_tokens_cpu[i] > 0: + # print("found one chunk prefill request, computed token is: ", input_batch.num_computed_tokens_cpu[i]) + new_chunk_prefill.append(i) + else: + new_prefills.append(i) + print("decodes: ", new_decode) + print("append: ", new_chunk_prefill) + print("prefills: ", new_prefills) + + # print("into split d p c") + # print('before reorder') + # print_order_of_batch() + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= decode_threshold: + decodes.append(i) + elif input_batch.num_computed_tokens_cpu[i] > 0: + # print("found one chunk prefill request, computed token is: ", input_batch.num_computed_tokens_cpu[i]) + chunk_prefills.append(i) + else: + prefills.append(i) + + num_decodes = len(decodes) + num_chunk_prefills = len(chunk_prefills) + # We define the reorder matrix here to help on the request reorder + # reorder_matrix[(i, j)] means the id the the requests that suppose to be in + # zone i but actually spot on zone j + # The decode, chunk prefill and pure prefill are separated into 3 different zone + # here, 0 for decode, 1 for chunk prefill and 2 for pure prefill + reorder_matrix = {(i, j): deque() for i in range(3) for j in range(3) if i!=j} + # collect mismatch + + def target_idx(idx): + if idx < num_decodes: + # decode as zone 0 + return 0 + elif idx < num_decodes + num_chunk_prefills: + # chunk prefill as zone 1 + return 1 + else: + # pure prefill as zone 2 + return 2 + + def fill_reorder_matrix(request_lists, reorder_sequence): + for idx, seq in enumerate(reorder_sequence): + request_list = request_lists[idx] + for req_idx in request_list: + req_target_id = target_idx(req_idx) + if seq != req_target_id: + reorder_matrix[(seq, req_target_id)].append(req_idx) + # print("reorder matrix: ", reorder_matrix) + + def direct_zone_swap(i, j): + assert i != j + modified_batch = False + while reorder_matrix[(i, j)] and reorder_matrix[(j, i)]: + swap_req1 = reorder_matrix[(i, j)].pop() + swap_req2 = reorder_matrix[(j, i)].pop() + input_batch.swap_states(swap_req1, swap_req2) + modified_batch = True + + return modified_batch + + # in order 1,2,3, out order 3, 1, 2 + def indirect_zone_swap(zone_list): + assert len(zone_list) == 3 + modified_batch = False + while reorder_matrix[zone_list[0]] and reorder_matrix[zone_list[1]] and reorder_matrix[zone_list[2]]: + swap_req1 = reorder_matrix[zone_list[0]].pop() + swap_req2 = reorder_matrix[zone_list[1]].pop() + swap_req3 = reorder_matrix[zone_list[2]].pop() + # print("do indirect swap: ", swap_req1, swap_req2, swap_req3) + # print("desired order should be : ", swap_req3, swap_req1, swap_req2) + + input_batch.swap_states(swap_req1, swap_req2) + input_batch.swap_states(swap_req2, swap_req3) + modified_batch = True + return modified_batch + + + fill_reorder_matrix([decodes, chunk_prefills, prefills], [0, 1, 2]) + + modified_batch = False 
+ # do directly swap for + modified_batch &= direct_zone_swap(0, 1) # decode <--> chunk prefill + modified_batch &= direct_zone_swap(0, 2) # decode <--> pure prefill + modified_batch &= direct_zone_swap(1, 2) # chunk prefill <--> pure prefill + + modified_batch &= indirect_zone_swap(((0, 1), (1, 2), (2, 0))) + modified_batch &= indirect_zone_swap(((2, 1), (0, 2), (1, 0))) + + # print("after reorder") + # print_order_of_batch() + + return modified_batch + + + def reorder_batch_to_split_decodes_and_prefills( input_batch: "InputBatch", scheduler_output: "SchedulerOutput", decode_threshold: int = 1, + reorder_append_prefills: bool = False, ) -> bool: """ Reorders the batch to split into prefill and decode requests; places all diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3bd2fe2f051..c08c59615dc 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -177,6 +177,8 @@ def schedule(self) -> SchedulerOutput: # and the "jump decoding" optimization in the future. scheduled_new_reqs: list[Request] = [] + new_reqs_for_pure_preill: list[Request] = [] + new_reqs_for_chunk_prefill: list[Request] = [] scheduled_resumed_reqs: list[Request] = [] scheduled_running_reqs: list[Request] = [] preempted_reqs: list[Request] = [] @@ -488,18 +490,23 @@ def schedule(self) -> SchedulerOutput: request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + if num_computed_tokens > 0 and self.scheduler_config.split_prefill_from_chunk: + new_reqs_for_chunk_prefill.append(request) + else: + new_reqs_for_pure_preill.append(request) + req_index += 1 - self.running.append(request) + # self.running.append(request) if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) - if request.status == RequestStatus.WAITING: - scheduled_new_reqs.append(request) - elif request.status == RequestStatus.PREEMPTED: - scheduled_resumed_reqs.append(request) - else: - raise RuntimeError( - f"Invalid request status: {request.status}") + # if request.status == RequestStatus.WAITING: + # scheduled_new_reqs.append(request) + # elif request.status == RequestStatus.PREEMPTED: + # scheduled_resumed_reqs.append(request) + # else: + # raise RuntimeError( + # f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) @@ -507,7 +514,7 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_blocks(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens - request.status = RequestStatus.RUNNING + # request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens # Count the number of prefix cached tokens. if request.num_cached_tokens < 0: @@ -521,6 +528,22 @@ def schedule(self) -> SchedulerOutput: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget + # reorder the request during scheduling, put chunked prefill at the top of + # the scheduled_new_reqs to make sure the actual reorder in model runner + # happens as less as possible. 
+ new_reqs_for_chunk_prefill.extend(new_reqs_for_pure_preill) + for req in new_reqs_for_chunk_prefill: + self.running.append(req) + + if req.status == RequestStatus.WAITING: + scheduled_new_reqs.append(req) + elif req.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(req) + else: + raise RuntimeError( + f"Invalid request status: {req.status}") + req.status = RequestStatus.RUNNING + # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f79abe13f5e..6231dca8ada 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -59,7 +59,9 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_kv_sharing_fast_prefill_attention_metadata, - reorder_batch_to_split_decodes_and_prefills) + reorder_batch_to_split_decodes_and_prefills, + reorder_batch_to_split_decodes_prefills_and_chunks) +>>>>>>> 6ca4159b9 (refactor attention backend for perf boost) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -385,10 +387,16 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: return if self.reorder_batch_threshold is not None: - reorder_batch_to_split_decodes_and_prefills( - self.input_batch, - scheduler_output, - decode_threshold=self.reorder_batch_threshold) + if self.scheduler_config.split_prefill_from_chunk: + reorder_batch_to_split_decodes_prefills_and_chunks( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + else: + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) # Note: used for model runner override. def _init_device_properties(self) -> None: @@ -483,7 +491,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(req_state) - reqs_to_add.append(req_state) # Update the states of the running/resumed requests. @@ -532,6 +539,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. + reqs_to_add.append(req_state) continue @@ -570,6 +578,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
for request in reqs_to_add: + self.input_batch.add_request(request) # Condense the batched states if there are gaps left by removed requests From a65fe1c7dc9cb3e4aa5c572fe93173473b46ff3a Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 29 Sep 2025 10:12:42 +0800 Subject: [PATCH 2/3] fix 355 gap Signed-off-by: zhuyuhua-v --- vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 804901c91b2..fd5641e0858 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -32,7 +32,7 @@ import triton import triton.language as tl from vllm.utils import direct_register_custom_op - from aiter.ops.triton.utils.device_info import get_num_sms + from aiter.ops.triton.utils.arch_info import get_num_sms def block_size(x, head_dim): return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6231dca8ada..d4d87c8aa2e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -61,7 +61,6 @@ make_kv_sharing_fast_prefill_attention_metadata, reorder_batch_to_split_decodes_and_prefills, reorder_batch_to_split_decodes_prefills_and_chunks) ->>>>>>> 6ca4159b9 (refactor attention backend for perf boost) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, From 2e1757d1ee963e3423fd3ccc94cf624b4be7e8bb Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 29 Sep 2025 10:13:50 +0800 Subject: [PATCH 3/3] add dispatcher Signed-off-by: zhuyuhua-v --- vllm/platforms/rocm.py | 15 +- .../backends/rocm_mha_backend_helper.py | 128 ++++++++++++++++++ vllm/v1/attention/backends/triton_attn.py | 54 ++++---- 3 files changed, 157 insertions(+), 40 deletions(-) create mode 100644 vllm/v1/attention/backends/rocm_mha_backend_helper.py diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 785f618f84c..59cab293f12 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -239,16 +239,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if selected_backend is None or selected_backend == _Backend.FLASH_ATTN: selected_backend = _Backend.ROCM_FLASH - if envs.VLLM_USE_V1: - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ - and on_gfx9(): - logger.info("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "rocm_aiter_fa.AiterFlashAttentionBackend") - else: - logger.info("Using Triton Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "triton_attn.TritonAttentionBackend") + if envs.VLLM_USE_V1: + from vllm.v1.attention.backends.rocm_mha_backend_helper import get_rocm_mha_backend_selection + backend_class_path, _ = get_rocm_mha_backend_selection() + if backend_class_path: + return backend_class_path if selected_backend == _Backend.ROCM_FLASH: if not cls.has_device_capability(90): # not Instinct series GPUs. 
diff --git a/vllm/v1/attention/backends/rocm_mha_backend_helper.py b/vllm/v1/attention/backends/rocm_mha_backend_helper.py new file mode 100644 index 00000000000..5db8ea272df --- /dev/null +++ b/vllm/v1/attention/backends/rocm_mha_backend_helper.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Helper functions for ROCm attention backend selection logic. +Centralized logic for choosing between different attention backends and implementations. +""" + +from typing import Optional, Tuple, Callable +from vllm import envs +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +def get_rocm_mha_backend_selection() -> Tuple[str, Optional[str]]: + """ + Centralized logic for ROCm attention backend selection. + + Returns: + tuple: (backend_class_path, unified_attention_impl_path) + - backend_class_path: Full class path for the selected backend + - unified_attention_impl_path: Path to unified attention implementation (if applicable) + + Priority order: + 1. AITER MHA: If VLLM_ROCM_USE_AITER=1 and VLLM_ROCM_USE_AITER_MHA=1 + 2. AITER Unified: If VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 and VLLM_USE_AITER_UNIFIED_ATTENTION=1 + 3. vLLM Unified: Default unified attention implementation + 4. Dynamic: Fallback to dynamic backend + """ + if not current_platform.is_rocm(): + return None, None + + # Check AITER availability + aiter_available = False + try: + import aiter # noqa: F401 + aiter_available = True + except Exception: + aiter_available = False + + # Priority 1: AITER MHA if both flags are on and available + if (envs.VLLM_ROCM_USE_AITER and + envs.VLLM_ROCM_USE_AITER_MHA and + aiter_available): + logger.info("ROCm Backend Selection: Using AITER FlashAttention backend") + return ("vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", None) + + # Priority 2: AITER unified attention for Triton + if (envs.VLLM_USE_AITER_UNIFIED_ATTENTION and + not envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION and + aiter_available): + logger.info("ROCm Backend Selection: Using Triton backend with AITER unified attention") + return ("vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", + "aiter.ops.triton.unified_attention.unified_attention") + + # Priority 3: vLLM unified attention for Triton (default) + logger.info("ROCm Backend Selection: Using Triton backend with vLLM unified attention") + return ("vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", + "vllm.attention.ops.triton_unified_attention.unified_attention") + + + +def get_unified_attention_impl() -> Optional[Callable]: + """ + Get the appropriate unified attention implementation based on environment variables. 
+ + Returns: + Callable or None: The unified attention function to use, or None for split path + """ + if not current_platform.is_rocm(): + return None + + # Check AITER availability + aiter_available = False + try: + import aiter # noqa: F401 + aiter_available = True + except Exception: + aiter_available = False + + # Priority 1: AITER unified attention + if (envs.VLLM_USE_AITER_UNIFIED_ATTENTION and + not envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION and + aiter_available): + try: + from aiter.ops.triton.unified_attention import unified_attention + logger.info("ROCm Unified Attention: Using AITER implementation") + return unified_attention + except Exception: + pass + + # Priority 2: vLLM unified attention + else: + try: + from vllm.attention.ops.triton_unified_attention import unified_attention + logger.info("ROCm Unified Attention: Using vLLM implementation") + return unified_attention + except Exception: + pass + + # Default: use split path + logger.info("ROCm Unified Attention: Using split prefill/decode attention") + return None + + +# def should_use_aiter_mha() -> bool: +# """ +# Check if AITER MHA backend should be used. + +# Returns: +# bool: True if AITER MHA should be used +# """ +# if not current_platform.is_rocm(): +# return False + +# # Check AITER availability +# aiter_available = False +# try: +# import aiter # noqa: F401 +# aiter_available = True +# except Exception: +# aiter_available = False + +# return (envs.VLLM_ROCM_USE_AITER and +# envs.VLLM_ROCM_USE_AITER_MHA and +# aiter_available) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 533eb1cc5f7..7054e7a6764 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -200,13 +200,6 @@ def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: return TritonAttentionMetadataBuilder -@cache -def use_aiter_unified_attention() -> bool: - """Check if aiter unified attention should be used.""" - # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set - # to 1 as default - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_USE_AITER_UNIFIED_ATTENTION class TritonAttentionImpl(AttentionImpl): @@ -258,24 +251,12 @@ def __init__( self.fp8_dtype = current_platform.fp8_dtype() - # If not using prefill decode attention, we use the Triton - # unified attention implementation. - if use_aiter_unified_attention(): - logger.info_once( - "Using aiter unified attention for TritonAttentionImpl") - from aiter.ops.triton.unified_attention import unified_attention - - self.unified_attention = unified_attention - elif not envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION: - logger.info_once( - "Using vllm unified attention for TritonAttentionImpl") - from vllm.attention.ops.triton_unified_attention import ( - unified_attention) - self.unified_attention = unified_attention - else: - logger.info_once( - "Using vllm split prefill decode attention for TritonAttentionImpl" - ) + # Unified attention implementation to be provided by the caller + # at runtime (via rocm_dynamic dispatcher). + self._unified_attention = None + + # Auto-set unified attention implementation based on environment variables + self._setup_unified_attention_impl() self.sinks = sinks if sinks is not None: @@ -334,10 +315,11 @@ def forward( # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - use_prefill_decode_attn = ( + # Runtime choice strictly controlled by rocm_dynamic dispatcher. 
+ # If a unified attention implementation is set, use it; otherwise + # use split prefill/decode path. + use_prefill_decode_attn = (self._unified_attention is None) or \ envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - and not use_aiter_unified_attention() - ) num_actual_tokens = attn_metadata.num_actual_tokens if use_prefill_decode_attn: @@ -469,7 +451,9 @@ def forward( else: descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - self.unified_attention( + # Use the unified attention implementation provided by dispatcher + unified_impl = self._unified_attention + unified_impl( q=query[:num_actual_tokens], k=key_cache, v=value_cache, @@ -491,4 +475,14 @@ def forward( output_scale=output_scale, ) - return output + + def _setup_unified_attention_impl(self) -> None: + """Auto-setup unified attention implementation based on environment variables.""" + from vllm.v1.attention.backends.rocm_mha_backend_helper import get_unified_attention_impl + self._unified_attention = get_unified_attention_impl() + + + def set_unified_attention_impl(self, fn) -> None: + # Set the callable for unified attention, or None to force split path + self._unified_attention = fn +
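For reference, the partitioning contract that patch 1/3 introduces in split_decodes_prefills_and_chunk() can be summarized with a minimal standalone sketch. It uses plain Python lists instead of CommonAttentionMetadata, and the helper name split_counts is made up purely for illustration:

# Illustrative sketch only: mirrors the classification rule used by
# split_decodes_prefills_and_chunk() in patch 1/3, on plain lists.
from typing import Sequence, Tuple

def split_counts(query_lens: Sequence[int],
                 seq_lens: Sequence[int],
                 decode_threshold: int = 1) -> Tuple[int, int, int, int, int, int]:
    """Return (num_decodes, num_chunk_prefills, num_pure_prefills,
    num_decode_tokens, num_chunk_prefill_tokens, num_pure_prefill_tokens)."""
    num_decodes = num_chunks = num_pures = 0
    decode_tok = chunk_tok = pure_tok = 0
    for q_len, s_len in zip(query_lens, seq_lens):
        if q_len <= decode_threshold:
            num_decodes += 1
            decode_tok += q_len
        elif s_len > q_len:
            # Part of the sequence's KV is already in the cache: chunked prefill.
            num_chunks += 1
            chunk_tok += q_len
        else:
            # seq_len == query_len: no prior context, pure prefill.
            num_pures += 1
            pure_tok += q_len
    return (num_decodes, num_chunks, num_pures,
            decode_tok, chunk_tok, pure_tok)

# Example: one decode, one chunked prefill (128 cached tokens), one pure prefill.
print(split_counts(query_lens=[1, 64, 256], seq_lens=[40, 192, 256]))
# -> (1, 1, 1, 1, 64, 256)

The real helper additionally assumes the batch has already been reordered into [decodes | chunked prefills | pure prefills] (which is what reorder_batch_to_split_decodes_prefills_and_chunks and the scheduler change enforce) and locates the zone boundaries with argmax over the reordered tensors instead of classifying each request independently.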
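Similarly, the chunked-context path in chunk_prefill_forward() (patch 1/3) attends over the cached context in fixed-size chunks and combines the partial results with merge_attn_states. A minimal sketch of that log-sum-exp merge rule, on simplified [num_tokens, dim] shapes (the real kernel operates on [tokens, heads, head_dim] and follows aiter's LSE layout conventions):

# Illustrative sketch of the state merge performed by merge_attn_states.
import torch

def merge_states(o1: torch.Tensor, lse1: torch.Tensor,
                 o2: torch.Tensor, lse2: torch.Tensor):
    # lse1/lse2: per-token log-sum-exp of the attention logits of each partial result.
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    # Weighted combination is mathematically equal to attending over the
    # union of the two key/value sets in one pass.
    return w1 * o1 + w2 * o2, lse

Because each partial result carries its own log-sum-exp, the context can be gathered and attended in _CHUNK_PREFILL_TOKENS_PER_ITER_ROCM-sized pieces with a fixed workspace, and the per-chunk outputs can be merged incrementally without ever materializing the full context KV at once.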