diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index daf094d2df5c..c1d312f9ef78 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -142,6 +142,10 @@ class SchedulerConfig: structured outputs, speculative decoding, and pipeline parallelism. """ + split_prefill_from_chunk: bool = False + """Whether to split the prefill request into pure prefill and chunked + prefill in a single batch.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1dacd026b667..9439a72123b2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -343,6 +343,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: and "-rms_norm" not in compilation_config.custom_ops): compilation_config.custom_ops.append("+rms_norm") + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA: + # enable the request reorder if we are using AITER MHA + # for calculation + vllm_config.scheduler_config.split_prefill_from_chunk = True + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 96f8e92a2039..66d93eef9475 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,206 +2,194 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import ClassVar, Optional import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import cdiv from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + split_decodes_prefills_and_chunk) from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 +_CP_TOKENS_PER_ITER_ROCM = 32 * 1024 if current_platform.is_rocm(): import aiter + # from vllm.triton_utils import tl, triton + import triton + import triton.language as tl + from aiter.ops.triton.utils.device_info import get_num_sms - from vllm.triton_utils import tl, triton - from vllm.utils import direct_register_custom_op + def block_size(x, head_dim): + return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) + + def num_programs(head_dim): + return min(head_dim, get_num_sms()) @triton.jit - def _vllm_layout_trans_kernel( - k_buffer_ptr, - v_buffer_ptr, - k_values_ptr, - v_values_ptr, - b_query_lens_loc, - b_seq_lens_loc, - block_table, - block_table_stride_0, - k_scale, - v_scale, - output_dtype: tl.constexpr, - E_DIM: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - ): - batch_idx = tl.program_id(0) - block_idx = tl.program_id(1) - - batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + - tl.arange(0, 2)) - batch_query_start, batch_query_end = tl.split(batch_query_indexes) - query_len = batch_query_end - batch_query_start - - if query_len <= 1: - return - - batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + - tl.arange(0, 2)) - batch_token_start, batch_token_end = tl.split(batch_token_indexes) - seq_len = batch_token_end - 
batch_token_start - - if block_idx * BLOCK_SIZE < seq_len: - block_mask = (block_idx * BLOCK_SIZE + - tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len - - kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 + - block_idx).to(tl.int64) - - kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange( - 0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :] - k_vals = tl.load(k_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) - if k_vals.dtype.is_fp8(): - k_vals = (k_vals.to(tl.float32) * - tl.load(k_scale)).to(output_dtype) - else: - k_vals = k_vals.to(output_dtype) - - v_vals = tl.load(v_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) - if v_vals.dtype.is_fp8(): - v_vals = (v_vals.to(tl.float32) * - tl.load(v_scale)).to(output_dtype) - else: - v_vals = v_vals.to(output_dtype) - kv_values_off = batch_token_start * E_DIM + \ - block_idx * BLOCK_SIZE * E_DIM + \ - tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \ - tl.arange(0, E_DIM)[None, :] - tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask) - tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask) - - def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, - k_cache, v_cache, max_seq_len, k_scale, v_scale, - output_dtype, total_tokens): - H_KV = v_cache.shape[2] - D = v_cache.shape[3] - BLOCK_SIZE = v_cache.shape[1] - - k_values = torch.empty( - (total_tokens, H_KV, D), - dtype=output_dtype, - device=k_cache.device, - ) - v_values = torch.empty( - (total_tokens, H_KV, D), - dtype=output_dtype, - device=v_cache.device, - ) + def cp_mha_gather_cache_kernel( + key_cache_ptr, # [num_blocks, page_size, num_head, head_size] + value_cache_ptr, # [num_blocks, page_size, num_head, head_size] + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size] + block_table_ptr, # [num_batches, max_block_num] + cu_seqlens_kv_ptr, # [num_batches + 1] + token_to_batch_ptr, # [max_cum_tokens] + seq_start_ptr, # [num_batches] + k_scale_ptr, + v_scale_ptr, + num_heads, + head_size, + x, + max_block_num, + num_tokens, + DEQUANT: tl.constexpr, + PAGE_SIZE: tl.constexpr, + CACHE_FORMAT: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + NUM_PRGMS: tl.constexpr): + bid = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_SIZE) + if DEQUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + + for token_id in tl.range(bid, num_tokens, NUM_PRGMS): + key_ptr_offset = key_ptr + token_id * head_size * num_heads + value_ptr_offset = value_ptr + token_id * head_size * num_heads + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load(block_table_ptr + max_block_num * batch_idx + + block_offset) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = key_cache_ptr + block_id * num_heads * \ + head_size * PAGE_SIZE + slot_id * num_heads * head_size + value_cache_ptr_offset = value_cache_ptr + block_id * \ + num_heads * head_size * PAGE_SIZE + slot_id * \ + num_heads * head_size + + for i in tl.range(0, head_size * num_heads, BLOCK_SIZE): + mask = (col_offsets + i) < head_size * num_heads + k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, + mask=mask) + v_reg = 
tl.load(value_cache_ptr_offset + col_offsets + i, + mask=mask) + if DEQUANT: + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets + i, + k_reg, + mask=mask) + tl.store(value_ptr_offset + col_offsets + i, + v_reg, + mask=mask) + + def cp_mha_gather_cache(key_cache: torch.Tensor, value_cache: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + block_tables: torch.Tensor, k_scales: float, + v_scales: float, cu_seqlens_kv: torch.Tensor, + token_to_batch: torch.Tensor, + seq_starts: torch.Tensor, dequant: bool, + kv_cache_layout: str, total_tokens: int): + assert kv_cache_layout in [ + "v0", "NHD", "HND" + ], "kv_cache_layout only support v0, NHD, HND" + head_dim = key.shape[2] + x = 0 + assert dequant is True, "Currently, we only support "\ + "gather cache with dequant" + # For k cache layout: [num_blocks, num_heads, page_size, head_dim] + assert kv_cache_layout == "NHD", "ROCM_AITER_FA_BACKEND Only "\ + "support NHD kv cache layout for now" + assert head_dim == key_cache.shape[ + 3], "We assume your kv cache layout is [num_blocks, "\ + "page_size, num_heads, head_dim], but got otherwise" + page_size = key_cache.shape[1] + num_heads = key_cache.shape[2] + + NUM_PRGMS = num_programs(total_tokens) + BLOCK_SIZE = block_size(key_cache, head_dim) + grid = lambda meta: (NUM_PRGMS, ) + cp_mha_gather_cache_kernel[grid](key_cache, + value_cache, + key, + value, + block_tables, + cu_seqlens_kv, + token_to_batch, + seq_starts, + k_scales, + v_scales, + num_heads, + head_dim, + x, + block_tables.size(1), + total_tokens, + DEQUANT=dequant, + PAGE_SIZE=page_size, + CACHE_FORMAT=kv_cache_layout, + BLOCK_SIZE=BLOCK_SIZE, + NUM_PRGMS=NUM_PRGMS) - grid = (block_table.shape[0], - (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) - if output_dtype == torch.float16: - output_dtype = tl.float16 - elif output_dtype == torch.bfloat16: - output_dtype = tl.bfloat16 - else: - raise ValueError(f"Unsupported output dtype: {output_dtype}") - - _vllm_layout_trans_kernel[grid](k_cache, - v_cache, - k_values, - v_values, - b_query_lens_loc, - b_seq_lens_loc, - block_table, - block_table.stride(0), - k_scale, - v_scale, - output_dtype=output_dtype, - E_DIM=H_KV * D, - BLOCK_SIZE=BLOCK_SIZE) - - return k_values, v_values - - def flash_attn_varlen_func_impl( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - out: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - window_size: Optional[list[int]], # -1 means infinite context window - alibi_slopes: Optional[list[float]], - block_table: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - total_tokens: int = 0, - ) -> torch.Tensor: - if total_tokens == 0: - total_tokens = int(cu_seqlens_k[-1].item()) - k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table, - k_cache, v_cache, max_seqlen_k, k_scale, - v_scale, q.dtype, total_tokens) - - output = aiter.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - min_seqlen_q=1, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=True, - alibi_slopes=alibi_slopes, - window_size=window_size, - out=out, - ) - return output +logger = init_logger(__name__) - def flash_attn_varlen_func_fake( - q: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - out: 
torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: float, - window_size: Optional[list[int]], # -1 means infinite context window - alibi_slopes: Optional[list[float]], - block_table: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - total_tokens: int = 0, - ) -> torch.Tensor: - return torch.empty(q.shape[0], - q.shape[1], - v_cache.shape[-2], - dtype=q.dtype, - device=q.device) - direct_register_custom_op("flash_attn_varlen_func", - flash_attn_varlen_func_impl, ["out"], - flash_attn_varlen_func_fake, - dispatch_key=current_platform.dispatch_key) +@dataclass +class AiterFlashAttentionDecodeMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + + +@dataclass +class AiterFlashAttentionPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor -logger = init_logger(__name__) + +@dataclass +class AiterChunkContextMetadata: + workspace: torch.Tensor + cu_seq_lens_chunk: torch.Tensor + chunk_starts: torch.Tensor + token_to_batch: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + seq_lens: torch.Tensor + num_chunks: int + total_token_per_batch: list[int] + + +@dataclass +class AiterFlashAttentionChunkPrefillMetadata: + max_query_len: int + min_query_len: int + max_seq_len: int + query_start_loc: torch.Tensor + chunk_context_metadata: AiterChunkContextMetadata @dataclass @@ -222,7 +210,18 @@ class AiterFlashAttentionMetadata: seq_lens: torch.Tensor slot_mapping: torch.Tensor block_table: torch.Tensor - cu_seq_lens: Optional[torch.Tensor] + + # prefill and deocde split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + num_chunk_prefills: int + num_chunk_prefill_tokens: int + + decode_metadata: Optional[AiterFlashAttentionDecodeMetadata] + pure_prefill_metadata: Optional[AiterFlashAttentionPrefillMetadata] + chunk_prefill_metadata: Optional[AiterFlashAttentionChunkPrefillMetadata] # For cascade attention. 
use_cascade: bool @@ -232,7 +231,10 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( AttentionMetadataBuilder[AiterFlashAttentionMetadata]): - cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + cudagraph_support: ClassVar[ + AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + + reorder_batch_threshold: int = 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -253,6 +255,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.aot_sliding_window: Optional[tuple[int, int]] = None self.total_tokens: int = 0 + self.chunk_prefill_workspace_size = \ + _CP_TOKENS_PER_ITER_ROCM * self.num_heads_kv * self.headdim + + self.chunk_prefill_workspace = torch.empty( + [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim], + dtype=self.model_config.dtype, + device=device) + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata): self.total_tokens = self.model_config.max_model_len \ @@ -267,44 +277,136 @@ def build(self, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - num_actual_tokens = common_attn_metadata.num_actual_tokens - max_query_len = common_attn_metadata.max_query_len - max_seq_len = common_attn_metadata.max_seq_len - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - block_table_tensor = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - if max_query_len > 1: - # We pre-compute cumulative seq len needed for prefill attention - # here to avoid recomputing it for every layer - cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, + split_ret = \ + split_decodes_prefills_and_chunk(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) + + num_decodes, num_chunk_prefills, num_pure_prefills, num_decode_tokens,\ + num_chunk_prefill_tokens, num_pure_prefill_tokens = split_ret + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + seq_lens = common_attn_metadata.seq_lens_cpu + + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + decode_metadata = None + if num_decodes > 0: + decode_metadata = AiterFlashAttentionDecodeMetadata( + max_query_len=query_lens_cpu[:num_decodes].max().item(), + min_query_len=query_lens_cpu[:num_decodes].min().item(), + max_seq_len=seq_lens[:num_decodes].max().item(), + query_start_loc=common_attn_metadata. 
+ query_start_loc[:num_decodes + 1]) + + pure_prefill_metadata = None + if num_pure_prefills > 0: + query_lens_for_pure_prefill = query_lens_cpu[num_decodes + + num_chunk_prefills:] + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes + num_chunk_prefills:] + pure_prefill_metadata = AiterFlashAttentionPrefillMetadata( + max_query_len=query_lens_for_pure_prefill.max().item(), + min_query_len=query_lens_for_pure_prefill.min().item(), + max_seq_len=seq_lens[num_decodes + + num_chunk_prefills:].max().item(), + query_start_loc=query_start_loc_device - + query_start_loc_device[0]) + + chunk_prefill_metadata = None + if num_chunk_prefills > 0: + query_lens_for_chunk_prefill = query_lens_cpu[ + num_decodes:num_decodes + num_chunk_prefills] + seq_lens_for_chunk_prefill = common_attn_metadata.seq_lens_cpu[ + num_decodes:num_decodes + num_chunk_prefills] + computed_kv_lens = seq_lens_for_chunk_prefill - \ + query_lens_for_chunk_prefill + + # allocate the equal amount of workspace for + # each chunk prefill request + max_context_chunk = (_CP_TOKENS_PER_ITER_ROCM // + num_chunk_prefills) + num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk) + + chunk_starts = torch.arange( + num_chunks, dtype=torch.int32).unsqueeze(1).expand( + -1, num_chunk_prefills) * max_context_chunk + chunk_ends = torch.min(computed_kv_lens.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp( + min=0) # [num_chunks, num_chunk_prefills] + cu_seq_lens_cpu = torch.zeros([num_chunks, num_chunk_prefills + 1], + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item() + + range_idx = torch.arange(max_cum_tokens, + dtype=torch.int32)[None, None, :] + idx_to_batch_tensor = \ + range_idx == cu_seq_lens_cpu[:,1:][:, :,None] + idx_to_batch_tensor = idx_to_batch_tensor.sum( + dim=1) # [num_chunks, max_cum_tokens] + token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1) + + chunk_context_metadata = AiterChunkContextMetadata( + workspace=self.chunk_prefill_workspace, + cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, + non_blocking=True), + chunk_starts=chunk_starts.to(self.device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, + token_to_batch=token_to_batch_tensor.to(self.device, + non_blocking=True), + num_chunks=num_chunks, + total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist()) + + query_start_loc_device = common_attn_metadata.query_start_loc[ + num_decodes:num_decodes + num_chunk_prefills + 1] + seq_lens_device = common_attn_metadata.seq_lens[ + num_decodes:num_decodes + num_chunk_prefills] + cu_seq_lens = torch.zeros(num_chunk_prefills + 1, dtype=torch.int32, - device=seq_lens.device) - torch.cumsum(seq_lens, + device=seq_lens_device.device) + torch.cumsum(seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) - num_actual_kv_tokens = int(cu_seq_lens[-1].item()) - else: - cu_seq_lens = None - num_actual_kv_tokens = 0 + chunk_prefill_metadata = AiterFlashAttentionChunkPrefillMetadata( + max_query_len=query_lens_for_chunk_prefill.max().item(), + min_query_len=query_lens_for_chunk_prefill.min().item(), + max_seq_len=seq_lens[num_decodes:num_decodes + + num_chunk_prefills].max().item(), + query_start_loc=query_start_loc_device - + query_start_loc_device[0], + 
chunk_context_metadata=chunk_context_metadata) - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): - return None + num_actual_kv_tokens = torch.sum(seq_lens).item() use_cascade = common_prefix_len > 0 attn_metadata = AiterFlashAttentionMetadata( - num_actual_tokens=num_actual_tokens, + num_actual_tokens=common_attn_metadata.num_actual_tokens, num_actual_kv_tokens=num_actual_kv_tokens, - max_query_len=max_query_len, - query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table_tensor, - slot_mapping=slot_mapping, - cu_seq_lens=cu_seq_lens, + max_query_len=common_attn_metadata.max_query_len, + query_start_loc=common_attn_metadata.query_start_loc, + max_seq_len=common_attn_metadata.max_seq_len, + seq_lens=common_attn_metadata.seq_lens, + block_table=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_pure_prefills, + num_prefill_tokens=num_pure_prefill_tokens, + num_chunk_prefills=num_chunk_prefills, + num_chunk_prefill_tokens=num_chunk_prefill_tokens, + decode_metadata=decode_metadata, + pure_prefill_metadata=pure_prefill_metadata, + chunk_prefill_metadata=chunk_prefill_metadata, use_cascade=use_cascade, common_prefix_len=common_prefix_len, total_tokens=self.total_tokens, @@ -363,6 +465,7 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @@ -410,6 +513,108 @@ def __init__( "are not implemented for " "FlashAttentionImpl") + def chunk_prefill_forward( + self, + attn_metadata: AiterFlashAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + min_seqlen_q: int, + block_table: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: float, + v_scale: float, + ): + out, lse = aiter.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_q, + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True) + assert attn_metadata.chunk_prefill_metadata is not None + chunk_context_metadata = \ + attn_metadata.chunk_prefill_metadata.chunk_context_metadata + num_chunks = chunk_context_metadata.num_chunks + workspace = chunk_context_metadata.workspace + cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk + max_seqlens = chunk_context_metadata.max_seq_lens + chunk_starts = chunk_context_metadata.chunk_starts + token_to_batch = chunk_context_metadata.token_to_batch + total_token_per_batch = chunk_context_metadata.total_token_per_batch + key_fetched, value_fetched = workspace[0], workspace[1] + chunked_output = None + chunked_lse = None + for chunk_idx in range(num_chunks): + + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=cu_seqlens_kv[chunk_idx], + token_to_batch=token_to_batch[chunk_idx], + seq_starts=chunk_starts[chunk_idx], + dequant=True, + kv_cache_layout="NHD", + 
total_tokens=total_token_per_batch[chunk_idx], + ) + + suf_out, suf_lse = aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv[chunk_idx], + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlens[chunk_idx], + min_seqlen_q=min_seqlen_q, + dropout_p=0.0, + softmax_scale=self.scale, + causal=False, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=True) + if chunked_output is None: + chunked_output = suf_out + chunked_lse = suf_lse + else: + tmp_output = torch.empty_like(out) + tmp_lse = torch.empty_like(lse) + merge_attn_states(output=tmp_output, + output_lse=tmp_lse, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=suf_out, + suffix_lse=suf_lse) + chunked_output = tmp_output + chunked_lse = tmp_lse + + merge_attn_states( + output=output, + prefix_output=chunked_output, + prefix_lse=chunked_lse, + suffix_output=out, + suffix_lse=lse, + ) + def forward( self, layer: torch.nn.Module, @@ -449,24 +654,25 @@ def forward( return output # IMPORTANT! - # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in - # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead - # in this method. For example, `view` and `slice` (or `[:n]`) operations - # are surprisingly slow even in the case they do not invoke any GPU ops. + # NOTE(woosuk): With piece-wise CUDA graphs, this method is + # executed in eager-mode PyTorch. Thus, we need to be careful + # about any CPU overhead in this method. For example, `view` + # and `slice` (or `[:n]`) operations are surprisingly slow even + # in the case they do not invoke any GPU ops. # Minimize the PyTorch ops in this method as much as possible. # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. + # NOTE(woosuk): Here, key and value are padded while slot_mapping + # is not padded. However, we don't need to do + # key[:num_actual_tokens] and value[:num_actual_tokens] because + # the reshape_and_cache_flash op uses the slot_mapping's shape + # to determine the number of actual tokens. 
+ torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, @@ -482,67 +688,132 @@ def forward( key_cache = key_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype()) + # decode:chunk_prefill:pure_prefill + query = query[:num_actual_tokens] + key = key[:num_actual_tokens] + value = value[:num_actual_tokens] + + output_actual_tokens = output[:num_actual_tokens] + + num_decodes = attn_metadata.num_decodes + num_pure_prefills = attn_metadata.num_prefills + num_chunk_prefills = attn_metadata.num_chunk_prefills + + num_decode_tokens = attn_metadata.num_decode_tokens + num_chunk_prefill_tokens = attn_metadata.num_chunk_prefill_tokens if not attn_metadata.use_cascade: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - - if max_seqlen_q > 1: - torch.ops.vllm.flash_attn_varlen_func( - query[:num_actual_tokens], - key_cache, - value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, + + # calculate for pure prefills + if num_pure_prefills > 0: + assert attn_metadata.pure_prefill_metadata is not None + + prefill_query = query[num_decode_tokens + + num_chunk_prefill_tokens:] + prefill_key = key[num_decode_tokens + + num_chunk_prefill_tokens:] + prefill_value = value[num_decode_tokens + + num_chunk_prefill_tokens:] + + aiter.flash_attn_varlen_func( + q=prefill_query, + k=prefill_key, + v=prefill_value, + cu_seqlens_q=attn_metadata.pure_prefill_metadata. + query_start_loc, + cu_seqlens_k=attn_metadata.pure_prefill_metadata. + query_start_loc, + max_seqlen_q=attn_metadata.pure_prefill_metadata. + max_query_len, + max_seqlen_k=attn_metadata.pure_prefill_metadata. + max_seq_len, + min_seqlen_q=attn_metadata.pure_prefill_metadata. + min_query_len, + dropout_p=0.0, softmax_scale=self.scale, - alibi_slopes=self.alibi_slopes, + causal=True, window_size=self.sliding_window, - block_table=block_table, - cu_seqlens_k=attn_metadata.cu_seq_lens, + alibi_slopes=self.alibi_slopes, + out=output_actual_tokens[num_decode_tokens + + num_chunk_prefill_tokens:], + ) + + # calculate for chunk prefills + if num_chunk_prefills > 0: + assert attn_metadata.chunk_prefill_metadata is not None + chunk_prefill_querys = query[ + num_decode_tokens:num_decode_tokens + + num_chunk_prefill_tokens] + chunk_prefill_keys = key[num_decode_tokens:num_decode_tokens + + num_chunk_prefill_tokens] + chunk_prefill_values = value[ + num_decode_tokens:num_decode_tokens + + num_chunk_prefill_tokens] + chunk_prefill_outputs = output[ + num_decode_tokens:num_decode_tokens + + num_chunk_prefill_tokens] + self.chunk_prefill_forward( + attn_metadata=attn_metadata, + query=chunk_prefill_querys, + key=chunk_prefill_keys, + value=chunk_prefill_values, + key_cache=key_cache, + value_cache=value_cache, + output=chunk_prefill_outputs, + cu_seqlens_q=attn_metadata.chunk_prefill_metadata. + query_start_loc, + max_seqlen_q=attn_metadata.chunk_prefill_metadata. + max_query_len, + max_seqlen_k=attn_metadata.chunk_prefill_metadata. + max_seq_len, + min_seqlen_q=attn_metadata.chunk_prefill_metadata. + min_query_len, + block_table=attn_metadata. + block_table[num_decodes:num_decodes + num_chunk_prefills], + slot_mapping=attn_metadata. 
+ slot_mapping[num_decodes:num_decodes + num_chunk_prefills], k_scale=layer._k_scale, v_scale=layer._v_scale, - total_tokens=attn_metadata.num_actual_kv_tokens, ) - _, num_heads, head_size = query.shape - nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 - num_seqs = seqused_k.shape[0] - max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM - - 1) // _PARTITION_SIZE_ROCM - - workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) * - nbytes_per_qo_elem + 2 * - (num_seqs * num_heads * max_num_partitions) * 4, - dtype=torch.uint8, - device=output.device, - ) + # calculate for decodes + if num_decodes > 0: + assert attn_metadata.decode_metadata is not None + _, num_heads, head_size = query.shape + nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 + max_num_partitions = ( + attn_metadata.decode_metadata.max_seq_len + + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM + + workspace_buffer = torch.empty( + (num_decode_tokens * num_heads * max_num_partitions * + head_size) * nbytes_per_qo_elem + 2 * + (num_decode_tokens * num_heads * max_num_partitions) * 4, + dtype=torch.uint8, + device=output.device, + ) - torch.ops.aiter.paged_attention_v1( - output[:num_actual_tokens], - workspace_buffer, - query[:num_actual_tokens], - key_cache, - value_cache, - self.scale, - block_table, - cu_seqlens_q, - seqused_k, - max_seqlen_k, - self.alibi_slopes, - self.kv_cache_dtype, - "NHD", - self.logits_soft_cap, - layer._k_scale, - layer._v_scale, - None, - _PARTITION_SIZE_ROCM, - ) - return output + torch.ops.aiter.paged_attention_v1( + output_actual_tokens[:num_decode_tokens], + workspace_buffer, + query[:num_decode_tokens], + key_cache, + value_cache, + self.scale, + attn_metadata.block_table[:num_decodes], + attn_metadata.decode_metadata.query_start_loc, + attn_metadata.seq_lens[:num_decodes], + attn_metadata.decode_metadata.max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + "NHD", + self.logits_soft_cap, + layer._k_scale, + layer._v_scale, + None, + _PARTITION_SIZE_ROCM, + ) else: raise NotImplementedError( "Cascade attention is not implemented for ROCM AITER") + + return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f37a829f401c..8322e716d7d0 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,6 +4,7 @@ import enum import functools from abc import abstractmethod +from collections import deque from dataclasses import dataclass, fields, make_dataclass from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional, Protocol, TypeVar, Union, get_args) @@ -718,6 +719,70 @@ def subclass_attention_backend( {"get_builder_cls": lambda: builder_cls}) +def split_decodes_prefills_and_chunk( + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: CommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. 
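+        num_chunk_prefills: The number of chunked prefill requests, i.e.
+            prefills that already have computed KV context
+            (seq_len > query_len).
+        num_chunk_prefill_tokens: The number of tokens in the chunked
+            prefill requests.
+
+        The tuple is returned as (num_decodes, num_chunk_prefills,
+        num_prefills, num_decode_tokens, num_chunk_prefill_tokens,
+        num_prefill_tokens).
+
+    Example:
+        With decode_threshold=1, query_lens=[1, 1, 4, 6] and
+        seq_lens=[8, 9, 10, 6], the first two requests are decodes, the
+        third is a chunked prefill (6 of its 10 tokens were computed
+        earlier) and the last is a pure prefill, so the function returns
+        (2, 1, 1, 2, 4, 6).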
+ """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + seq_lens = common_attn_metadata.seq_lens_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, 0, num_tokens, 0, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, 0, num_tokens, 0, 0 + + first_prefill = is_prefill.int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] > decode_threshold) + assert torch.all(query_lens[:first_prefill] <= decode_threshold) + num_decodes = first_prefill + num_decode_tokens = query_start_loc[first_prefill].item() + + query_lens_prefill = query_lens[first_prefill:] + seq_lens_prefill = seq_lens[first_prefill:] + is_pure_prefill = seq_lens_prefill == query_lens_prefill + + if torch.all(is_pure_prefill): + num_pure_prefills = num_reqs - num_decodes + num_pure_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, 0, num_pure_prefills, num_decode_tokens, 0, + num_pure_prefill_tokens) + + num_prefills = num_reqs - num_decodes + first_chunk_prefill = is_pure_prefill.int().argmax(dim=-1).item() + + num_chunk_prefills = first_chunk_prefill + num_pure_prefills = num_prefills - first_chunk_prefill + + num_chunk_prefill_tokens = query_start_loc[ + num_chunk_prefills + num_decodes].item() - num_decode_tokens + num_pure_prefill_tokens = num_tokens - num_decode_tokens - \ + num_chunk_prefill_tokens + return (num_decodes, num_chunk_prefills, num_pure_prefills, + num_decode_tokens, num_chunk_prefill_tokens, + num_pure_prefill_tokens) + + def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, @@ -771,10 +836,119 @@ def split_decodes_and_prefills( return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) +def reorder_batch_to_split_decodes_prefills_and_chunks( + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + decode_threshold: int = 1, +) -> bool: + """ + Reorders the batch to split into prefill, chunk_prefill + and decode requests; places all requests in the order of + [decodes:chunked_prefills:pure_prefills]. + + Returns: + True if the batch was modified, False otherwise. + """ + + # We assume most of the request is already in the order of what + # we desired since this function should only be opened after the + # `SchedulerConfig.split_prefill_from_chunk` is True. So we only + # need to spot all mismatched request and swap their positions + # for efficiency. 
+ + decodes = [] + prefills = [] + chunk_prefills = [] + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens <= decode_threshold: + decodes.append(i) + elif input_batch.num_computed_tokens_cpu[i] > 0: + chunk_prefills.append(i) + else: + prefills.append(i) + + num_decodes = len(decodes) + num_chunk_prefills = len(chunk_prefills) + # We define the reorder matrix here to help on the request reorder + # reorder_matrix[(i, j)] means the id the the requests that suppose + # to be in zone i but actually spot on zone j + # The decode, chunk prefill and pure prefill are separated into 3 + # different zone here, 0 for decode, 1 for chunk prefill and 2 for + # pure prefill + reorder_matrix: dict[tuple[int, int], deque[int]] = { + (i, j): deque() + for i in range(3) + for j in range(3) if i != j + } + + # collect mismatch + + def target_idx(idx): + if idx < num_decodes: + # decode as zone 0 + return 0 + elif idx < num_decodes + num_chunk_prefills: + # chunk prefill as zone 1 + return 1 + else: + # pure prefill as zone 2 + return 2 + + def fill_reorder_matrix(request_lists, reorder_sequence): + for idx, seq in enumerate(reorder_sequence): + request_list = request_lists[idx] + for req_idx in request_list: + req_target_id = target_idx(req_idx) + if seq != req_target_id: + reorder_matrix[(seq, req_target_id)].append(req_idx) + + def direct_zone_swap(i, j): + assert i != j + modified_batch = False + while reorder_matrix[(i, j)] and reorder_matrix[(j, i)]: + swap_req1 = reorder_matrix[(i, j)].pop() + swap_req2 = reorder_matrix[(j, i)].pop() + input_batch.swap_states(swap_req1, swap_req2) + modified_batch = True + + return modified_batch + + # in order 1,2,3, out order 3, 1, 2 + def indirect_zone_swap(zone_list): + assert len(zone_list) == 3 + modified_batch = False + while reorder_matrix[zone_list[0]] and reorder_matrix[ + zone_list[1]] and reorder_matrix[zone_list[2]]: + swap_req1 = reorder_matrix[zone_list[0]].pop() + swap_req2 = reorder_matrix[zone_list[1]].pop() + swap_req3 = reorder_matrix[zone_list[2]].pop() + + input_batch.swap_states(swap_req1, swap_req2) + input_batch.swap_states(swap_req2, swap_req3) + modified_batch = True + return modified_batch + + fill_reorder_matrix([decodes, chunk_prefills, prefills], [0, 1, 2]) + + modified_batch = False + # do directly swap for + modified_batch |= direct_zone_swap(0, 1) # decode <--> chunk prefill + modified_batch |= direct_zone_swap(0, 2) # decode <--> pure prefill + modified_batch |= direct_zone_swap(1, 2) # chunk prefill <--> pure prefill + + modified_batch |= indirect_zone_swap(((0, 1), (1, 2), (2, 0))) + modified_batch |= indirect_zone_swap(((2, 1), (0, 2), (1, 0))) + + return modified_batch + + def reorder_batch_to_split_decodes_and_prefills( input_batch: "InputBatch", scheduler_output: "SchedulerOutput", decode_threshold: int = 1, + reorder_append_prefills: bool = False, ) -> bool: """ Reorders the batch to split into prefill and decode requests; places all diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 10d8f6bbda5c..b21c8badf339 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -189,6 +189,8 @@ def schedule(self) -> SchedulerOutput: # and the "jump decoding" optimization in the future. 
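+        # Newly scheduled requests are collected into two lists (chunked
+        # prefills vs. pure prefills) so that chunked prefills can be
+        # appended to self.running first; see below where they are added.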
scheduled_new_reqs: list[Request] = [] + new_reqs_for_pure_preill: list[Request] = [] + new_reqs_for_chunk_prefill: list[Request] = [] scheduled_resumed_reqs: list[Request] = [] scheduled_running_reqs: list[Request] = [] preempted_reqs: list[Request] = [] @@ -334,9 +336,10 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests = create_request_queue(self.policy) # Next, schedule the WAITING requests. + running_req_cnt = len(self.running) if not preempted_reqs: while self.waiting and token_budget > 0: - if len(self.running) == self.max_num_running_reqs: + if running_req_cnt == self.max_num_running_reqs: break request = self.waiting.peek_request() @@ -505,18 +508,18 @@ def schedule(self) -> SchedulerOutput: request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue + if num_computed_tokens > 0 and \ + self.scheduler_config.split_prefill_from_chunk: + new_reqs_for_chunk_prefill.append(request) + else: + new_reqs_for_pure_preill.append(request) + + running_req_cnt += 1 req_index += 1 - self.running.append(request) + if self.log_stats: request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp) - if request.status == RequestStatus.WAITING: - scheduled_new_reqs.append(request) - elif request.status == RequestStatus.PREEMPTED: - scheduled_resumed_reqs.append(request) - else: - raise RuntimeError( - f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) @@ -524,7 +527,7 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_blocks(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens - request.status = RequestStatus.RUNNING + # request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens # Count the number of prefix cached tokens. if request.num_cached_tokens < 0: @@ -538,6 +541,24 @@ def schedule(self) -> SchedulerOutput: self.encoder_cache_manager.allocate(request, i) encoder_compute_budget = new_encoder_compute_budget + # reorder the request during scheduling, put chunked prefill + # at the top of the scheduled_new_reqs to make sure the actual + # reorder in model runner happens as less as possible. 
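+        # New requests are therefore appended to self.running as
+        # [chunked prefills, pure prefills]; any remaining mismatches are
+        # fixed by reorder_batch_to_split_decodes_prefills_and_chunks in
+        # the model runner.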
+ assert running_req_cnt == len(new_reqs_for_chunk_prefill) + \ + len(new_reqs_for_pure_preill) + len(self.running) + + new_reqs_for_chunk_prefill.extend(new_reqs_for_pure_preill) + for req in new_reqs_for_chunk_prefill: + self.running.append(req) + + if req.status == RequestStatus.WAITING: + scheduled_new_reqs.append(req) + elif req.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(req) + else: + raise RuntimeError(f"Invalid request status: {req.status}") + req.status = RequestStatus.RUNNING + # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: self.waiting.prepend_requests(skipped_waiting_requests) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 22a177dd7cc7..fd18628ad1ce 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -70,7 +70,8 @@ from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) + reorder_batch_to_split_decodes_and_prefills, + reorder_batch_to_split_decodes_prefills_and_chunks, split_attn_metadata) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # yapf conflicts with isort for this block # yapf: disable @@ -518,10 +519,16 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: if self.dcp_world_size > 1: assert self.reorder_batch_threshold == 1, \ "DCP not support reorder_batch_threshold > 1 now." - reorder_batch_to_split_decodes_and_prefills( - self.input_batch, - scheduler_output, - decode_threshold=self.reorder_batch_threshold) + if self.scheduler_config.split_prefill_from_chunk: + reorder_batch_to_split_decodes_prefills_and_chunks( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) + else: + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) # Note: used for model runner override. def _init_device_properties(self) -> None: @@ -616,7 +623,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(req_state) - reqs_to_add.append(req_state) # Update the states of the running/resumed requests. @@ -665,6 +671,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. + reqs_to_add.append(req_state) continue @@ -703,6 +710,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for request in reqs_to_add: + self.input_batch.add_request(request) # Condense the batched states if there are gaps left by removed requests