7 changes: 7 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -15,6 +15,7 @@
ModelRunnerInputBase,
ModelRunnerInputBuilderBase)

from vllm.multistream.base import MSAttentionMetadataSplitConfig

class AttentionType:
"""
@@ -156,6 +157,12 @@ def asdict_zerocopy(self,
for field in fields(self) if field.name not in skip_fields
}

def split_metadata_for_multistream(self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> List["AttentionMetadata"]:
raise NotImplementedError



T = TypeVar("T", bound=AttentionMetadata)

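Note: the base class only declares the new hook and raises `NotImplementedError`; a backend-agnostic caller can treat that as "no split". A minimal caller-side sketch (the `try_split_for_multistream` helper is hypothetical and not part of this diff):

```python
from typing import List

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.multistream.base import MSAttentionMetadataSplitConfig


def try_split_for_multistream(
    attn_metadata: AttentionMetadata,
    ms_split_config: MSAttentionMetadataSplitConfig,
) -> List[AttentionMetadata]:
    """Split into micro-batches when the backend supports it."""
    try:
        return attn_metadata.split_metadata_for_multistream(ms_split_config)
    except NotImplementedError:
        # Backends that have not implemented the hook fall back to a
        # single micro-batch, effectively disabling multi-stream.
        return [attn_metadata]
```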
42 changes: 40 additions & 2 deletions vllm/attention/backends/flash_attn.py
@@ -21,7 +21,7 @@
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
is_all_encoder_attn_metadata_set, is_block_tables_empty, common_split_metadata_for_multistream)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -30,6 +30,8 @@
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)

from vllm.multistream.base import MSAttentionMetadataSplitConfig

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
@@ -109,6 +111,9 @@ class FlashAttentionMetadata(AttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The query length per sequence. Query length means the
# new tokens
query_lens: Optional[List[int]]
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
@@ -218,6 +223,8 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
query_lens = (None if self.query_lens is None else
self.query_lens[:self.num_prefills])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
@@ -237,6 +244,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
@@ -282,6 +290,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
query_lens=None,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
@@ -347,6 +356,8 @@ def advance_step(self,
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )

assert self.query_lens is not None
assert len(self.query_lens) <= num_seqs
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
@@ -368,6 +379,7 @@ def advance_step(self,
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.query_lens[i] += 1
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)

@@ -380,7 +392,33 @@ def advance_step(self,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)

def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> List["FlashAttentionMetadata"]:
"""Split metadata for multi-stream with FlashAttentionBackend"""
return common_split_metadata_for_multistream(
ms_split_config=ms_split_config,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping,
query_lens=self.query_lens,
seq_lens=self.seq_lens,
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps, # TODO maybe error
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens_tensor=self.seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=self.query_start_loc,
seq_start_loc=self.seq_start_loc,
context_lens_tensor=self.context_lens_tensor,
block_tables=self.block_tables,
use_cuda_graph=self.use_cuda_graph,
attn_metadata=self,
_metadata_cls=FlashAttentionMetadata,
)

class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
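The `FlashAttentionMetadata.split_metadata_for_multistream` override above delegates to the shared helper added to `utils.py` below. A hedged usage sketch follows; the keyword arguments of `MSAttentionMetadataSplitConfig` are inferred from the attributes the helper reads and may not match its real constructor:

```python
from typing import List

from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.multistream.base import MSAttentionMetadataSplitConfig


def split_for_two_streams(
    attn_metadata: FlashAttentionMetadata,
) -> List[FlashAttentionMetadata]:
    """Split a mixed prefill/decode batch into at most two micro-batches."""
    # Assumed field names: they mirror the attributes read inside
    # common_split_metadata_for_multistream (num_micro_batches,
    # enable_request_split, min_total_tokens_to_split,
    # min_prefill_tokens_to_split, imbalance_ratio).
    ms_split_config = MSAttentionMetadataSplitConfig(
        num_micro_batches=2,
        enable_request_split=True,
        min_total_tokens_to_split=256,
        min_prefill_tokens_to_split=64,
        imbalance_ratio=0.1,
    )
    # Returns [attn_metadata] unchanged for decode-only or small batches,
    # otherwise a [pre, post] pair covering disjoint token ranges.
    return attn_metadata.split_metadata_for_multistream(ms_split_config)
```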
200 changes: 200 additions & 0 deletions vllm/attention/backends/utils.py
@@ -3,6 +3,7 @@
from collections import defaultdict
from contextlib import contextmanager
from itertools import accumulate
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union

import numpy as np
@@ -15,6 +16,9 @@
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

from vllm.multistream.base import MSAttentionMetadataSplitConfig


logger = init_logger(__name__)

if TYPE_CHECKING:
@@ -272,6 +276,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
@@ -327,6 +332,7 @@ def graph_capture_get_metadata_for_batch(
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
query_lens=None,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
@@ -583,3 +589,197 @@ def get_num_prefill_decode_query_kv_tokens(

return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)


def common_split_metadata_for_multistream(
ms_split_config: MSAttentionMetadataSplitConfig,
num_prefills: int,
num_prefill_tokens: int,
num_decode_tokens: int,
slot_mapping: torch.Tensor,
query_lens: List[int],
seq_lens: List[int],
multi_modal_placeholder_index_maps: Dict[str, MultiModalPlaceholderMap.IndexMap],
enable_kv_scales_calculation: bool,
seq_lens_tensor: torch.Tensor,
max_query_len: int,
max_prefill_seq_len: int,
max_decode_seq_len: int,
query_start_loc: torch.Tensor,
seq_start_loc: torch.Tensor,
context_lens_tensor: torch.Tensor,
block_tables: torch.Tensor,
use_cuda_graph: bool,
attn_metadata: "AttentionMetadata",
_metadata_cls: Type[TAttentionMetadata],
) -> List[Any]:
assert 0 < ms_split_config.num_micro_batches < 3
assert ms_split_config.enable_request_split, "Only causal attention is supported for now."
# not support multi-stream for decode-only phase for now
if num_prefill_tokens == 0:
return [attn_metadata]

# get batches info
total_tokens = num_prefill_tokens + num_decode_tokens
if (total_tokens < ms_split_config.min_total_tokens_to_split or
num_prefill_tokens < ms_split_config.min_prefill_tokens_to_split):
return [attn_metadata]
mean_token_num = total_tokens // ms_split_config.num_micro_batches
token_imbalance_ratio = ms_split_config.imbalance_ratio

query_start_loc_cpu = np.zeros(shape=(len(query_lens) + 1,), dtype=int)
np.cumsum(query_lens, out=query_start_loc_cpu[1:])

# find a batch to split
split_batch_index = 0
need_chunk = False
for i in range(len(query_start_loc_cpu) - 1):
if query_start_loc_cpu[i] <= mean_token_num <= query_start_loc_cpu[i + 1] or i > num_prefills - 1:
split_batch_index = i
break

if split_batch_index > num_prefills - 1:
split_batch_index = split_batch_index - 1
need_chunk = False
else:
if abs(query_start_loc_cpu[split_batch_index] - mean_token_num) < total_tokens * token_imbalance_ratio:
split_batch_index = split_batch_index - 1
elif abs(query_start_loc_cpu[split_batch_index + 1] - mean_token_num) < total_tokens * token_imbalance_ratio:
split_batch_index = split_batch_index
else:
split_batch_index = split_batch_index
need_chunk = True

if not need_chunk:
# pre
num_prefills_pre = split_batch_index + 1
slot_mapping_pre = slot_mapping[:query_start_loc_cpu[split_batch_index + 1]]
num_prefills_tokens_pre = query_start_loc_cpu[split_batch_index + 1]
num_decode_tokens_pre = 0
query_lens_pre = query_lens[:split_batch_index + 1]
seq_lens_pre = seq_lens[:split_batch_index + 1]
seq_lens_tensor_pre = seq_lens_tensor[:split_batch_index + 1]
max_query_len_pre = max(query_lens[:split_batch_index + 1])
max_prefill_seq_len_pre = max(seq_lens_pre, default=0)
max_decode_query_len_pre = 1
max_decode_seq_len_pre = 0
query_start_loc_pre = deepcopy(query_start_loc[:split_batch_index + 2])
seq_start_loc_pre = deepcopy(seq_start_loc[:split_batch_index + 2])
context_lens_tensor_pre = context_lens_tensor[:split_batch_index + 1]
block_tables_pre = block_tables[:split_batch_index + 1]
use_cuda_graph_pre = use_cuda_graph
# post
num_prefills_post = num_prefills - num_prefills_pre
slot_mapping_post = slot_mapping[query_start_loc_cpu[split_batch_index + 1]:]
num_prefills_token_post = num_prefill_tokens - num_prefills_tokens_pre
num_decode_token_post = num_decode_tokens
seq_lens_post = seq_lens[split_batch_index + 1:]
seq_lens_tensor_post = seq_lens_tensor[split_batch_index + 1:]
query_lens_post = query_lens[split_batch_index + 1:]
max_query_len_post = max(query_lens_post)
max_prefill_seq_len_post = max(seq_lens_post[:num_prefills_post], default=0)
decode_query_lens = query_lens_post[num_prefills_post:]
if len(decode_query_lens) > 0:
max_decode_query_len_post = max(decode_query_lens, default=0)
else:
max_decode_query_len_post = 1
max_decode_seq_len_post = max_decode_seq_len
query_start_loc_post = deepcopy(query_start_loc[split_batch_index + 1:]) - \
query_start_loc[split_batch_index + 1]
seq_start_loc_post = deepcopy(seq_start_loc[split_batch_index + 1:]) - \
seq_start_loc[split_batch_index + 1]
context_lens_tensor_post = context_lens_tensor[split_batch_index + 1:]
block_tables_post = block_tables[split_batch_index + 1:]
use_cuda_graph_post = use_cuda_graph
else: # split one prefill request
split_tokens_pre = mean_token_num - query_start_loc_cpu[split_batch_index]
split_tokens_post = query_start_loc_cpu[split_batch_index + 1] - mean_token_num
# pre
num_prefills_pre = split_batch_index + 1
slot_mapping_pre = slot_mapping[:mean_token_num]
num_prefills_tokens_pre = mean_token_num
num_decode_tokens_pre = 0
seq_lens_pre = deepcopy(seq_lens[:split_batch_index + 1]) # deepcopy
seq_lens_pre[-1] = seq_lens_pre[-1] - split_tokens_post
seq_lens_tensor_pre = deepcopy(seq_lens_tensor[:split_batch_index + 1])
seq_lens_tensor_pre[-1] = seq_lens_tensor_pre[-1] - split_tokens_post
query_lens_pre = query_lens[:split_batch_index] + [split_tokens_pre]
max_query_len_pre = max(query_lens_pre)
max_prefill_seq_len_pre = max(seq_lens_pre, default=0)
max_decode_query_len_pre = 1
max_decode_seq_len_pre = 0
query_start_loc_pre = deepcopy(query_start_loc[:split_batch_index + 2])
query_start_loc_pre[-1] = query_start_loc_pre[-1] - split_tokens_post
seq_start_loc_pre = deepcopy(seq_start_loc[:split_batch_index + 2])
seq_start_loc_pre[-1] = seq_start_loc_pre[-1] - split_tokens_post
context_lens_tensor_pre = context_lens_tensor[:split_batch_index + 1]
block_tables_pre = block_tables[:split_batch_index + 1]
use_cuda_graph_pre = use_cuda_graph
# post
num_prefills_post = num_prefills - num_prefills_pre + 1
slot_mapping_post = slot_mapping[mean_token_num:]
num_prefills_token_post = num_prefill_tokens - num_prefills_tokens_pre
num_decode_token_post = num_decode_tokens
seq_lens_post = seq_lens[split_batch_index:]
seq_lens_tensor_post = seq_lens_tensor[split_batch_index:]
query_lens_post = [split_tokens_post] + query_lens[split_batch_index + 1:]
max_query_len_post = max(query_lens_post)
max_prefill_seq_len_post = max(seq_lens_post[:num_prefills_post], default=0)
decode_query_lens = query_lens_post[num_prefills_post:]
if len(decode_query_lens) > 0:
max_decode_query_len_post = max(decode_query_lens, default=0)
else:
max_decode_query_len_post = 1
max_decode_seq_len_post = max_decode_seq_len
query_start_loc_post = deepcopy(query_start_loc[split_batch_index:]) - \
query_start_loc[split_batch_index]
query_start_loc_post[1:] = query_start_loc_post[1:] - split_tokens_pre
seq_start_loc_post = deepcopy(seq_start_loc[split_batch_index:]) - \
seq_start_loc[split_batch_index]
context_lens_tensor_post = deepcopy(context_lens_tensor[split_batch_index:])
context_lens_tensor_post[0] = context_lens_tensor_post[0] + split_tokens_pre
block_tables_post = block_tables[split_batch_index:]
use_cuda_graph_post = use_cuda_graph

attention_metadata_pre = _metadata_cls(
num_prefills=num_prefills_pre,
slot_mapping=slot_mapping_pre,
num_prefill_tokens=num_prefills_tokens_pre,
num_decode_tokens=num_decode_tokens_pre,
query_lens=query_lens_pre,
seq_lens=seq_lens_pre,
multi_modal_placeholder_index_maps=multi_modal_placeholder_index_maps, # TODO maybe error
enable_kv_scales_calculation=enable_kv_scales_calculation,
seq_lens_tensor=seq_lens_tensor_pre,
max_query_len=max_query_len_pre,
max_decode_query_len=max_decode_query_len_pre,
max_prefill_seq_len=max_prefill_seq_len_pre,
max_decode_seq_len=max_decode_seq_len_pre,
query_start_loc=query_start_loc_pre,
seq_start_loc=seq_start_loc_pre,
context_lens_tensor=context_lens_tensor_pre,
block_tables=block_tables_pre,
use_cuda_graph=use_cuda_graph_pre,
)

attention_metadata_post = _metadata_cls(
num_prefills=num_prefills_post,
slot_mapping=slot_mapping_post,
num_prefill_tokens=num_prefills_token_post,
num_decode_tokens=num_decode_token_post,
query_lens=query_lens_post,
seq_lens=seq_lens_post,
multi_modal_placeholder_index_maps=multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=enable_kv_scales_calculation,
seq_lens_tensor=seq_lens_tensor_post,
max_query_len=max_query_len_post,
max_decode_query_len=max_decode_query_len_post,
max_prefill_seq_len=max_prefill_seq_len_post,
max_decode_seq_len=max_decode_seq_len_post,
query_start_loc=query_start_loc_post,
seq_start_loc=seq_start_loc_post,
context_lens_tensor=context_lens_tensor_post,
block_tables=block_tables_post,
use_cuda_graph=use_cuda_graph_post,
)
return [attention_metadata_pre, attention_metadata_post]
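To make the split-point selection easier to follow, here is a standalone numeric sketch of just the index search from `common_split_metadata_for_multistream` (it reproduces the decision step only, not the metadata slicing):

```python
import numpy as np

# Example batch: three prefill requests (100, 60, 40 tokens) followed by
# two decode requests (1 token each).
query_lens = [100, 60, 40, 1, 1]
num_prefills = 3
num_prefill_tokens = 200
num_decode_tokens = 2

total_tokens = num_prefill_tokens + num_decode_tokens  # 202
mean_token_num = total_tokens // 2                     # 101 tokens per micro-batch (2 streams)

query_start_loc_cpu = np.zeros(len(query_lens) + 1, dtype=int)
np.cumsum(query_lens, out=query_start_loc_cpu[1:])
# query_start_loc_cpu == [0, 100, 160, 200, 201, 202]

# Same search as in the helper: find the first request whose token range
# contains the mean, or stop once past the prefill requests.
split_batch_index = 0
for i in range(len(query_start_loc_cpu) - 1):
    if (query_start_loc_cpu[i] <= mean_token_num <= query_start_loc_cpu[i + 1]
            or i > num_prefills - 1):
        split_batch_index = i
        break

print(split_batch_index)  # 1: request 1 spans tokens [100, 160) and straddles 101
# Because abs(100 - 101) is tiny relative to total_tokens, a typical
# imbalance_ratio makes the helper step back to the request boundary and
# take the slightly unbalanced whole-request split (request {0} vs requests
# {1, 2, 3, 4}, i.e. 100 vs 102 tokens) rather than chunk request 1.
```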
11 changes: 11 additions & 0 deletions vllm/attention/selector.py
@@ -184,3 +184,14 @@ def global_force_attn_backend_context_manager(
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)


def verify_attn_backend(
attention_backend: Type[AttentionBackend] = None,
enable_multi_stream: bool = False,
):
if enable_multi_stream:
from vllm.attention.backends.flash_attn import FlashAttentionBackend
assert (attention_backend.get_name() == FlashAttentionBackend.get_name()), \
(f"enable_multi_stream only supports FlashAttentionBackend, "
f"now backend is {attention_backend.get_name()}")