
Commit 53df680

Preliminary blackwell enablement (vllm-project#54)
* Pad flashmla_sparse to 128 on blackwell
* adjust get_max_prefill_buffer_size
* change comments
1 parent e744e06 commit 53df680

2 files changed: +10 −6 lines changed


vllm/v1/attention/backends/mla/flashmla_sparse.py

Lines changed: 8 additions & 5 deletions
@@ -21,6 +21,7 @@
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.platforms import current_platform

 if TYPE_CHECKING:
     from vllm.model_executor.models.deepseek_v2 import Indexer

@@ -388,13 +389,15 @@ def _forward_bf16_kv(
         kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
             -1, 1, kv_c_and_k_pe_cache.shape[-1])

-        # NOTE(Chen): kernel requires num_local_head to be a multiple of 64.
-        if self.num_heads % 64 != 0:
-            assert 64 % self.num_heads == 0
+        # NOTE(Chen): kernel requires num_local_head to be a multiple of
+        # 64 on hopper and 128 on blackwell
+        padding = 128 if current_platform.is_device_capability(100) else 64
+        if self.num_heads % padding != 0:
+            assert padding % self.num_heads == 0
             logger.warning_once(
-                "padding num_heads to 64 due to sparse attn kernel requirement"
+                f"padding num_heads to {padding} due to sparse attn kernel requirement"
             )
-            q_padded = q.new_empty((q.shape[0], 64, q.shape[2]))
+            q_padded = q.new_empty((q.shape[0], padding, q.shape[2]))
             q_padded[:, :self.num_heads, :] = q
             q = q_padded

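For reference, here is a minimal standalone sketch of the padding step above. It is not the vLLM method itself: `pad_query_heads` is a hypothetical helper and the `is_blackwell` flag stands in for `current_platform.is_device_capability(100)`; only the pad-to-64/128 logic mirrors the diff.

```python
import torch


def pad_query_heads(q: torch.Tensor, num_heads: int,
                    is_blackwell: bool) -> torch.Tensor:
    """Sketch of the query-head padding; q has shape (num_tokens, num_heads, head_dim)."""
    # Sparse FlashMLA expects the head count to be a multiple of 64 on Hopper
    # and 128 on Blackwell (SM100), per the diff above.
    padding = 128 if is_blackwell else 64
    if num_heads % padding == 0:
        return q  # already a valid multiple, no copy needed
    # Only padding *up* to the multiple is supported, so the real head count
    # must divide it evenly (e.g. 16, 32, or 64 heads -> pad to 128).
    assert padding % num_heads == 0
    # Allocate a padded buffer (extra head slots left uninitialized, as in the
    # diff) and copy the real heads into the front.
    q_padded = q.new_empty((q.shape[0], padding, q.shape[2]))
    q_padded[:, :num_heads, :] = q
    return q_padded


# Example: 64 local heads on Blackwell get padded to 128 (head size is
# arbitrary for the example).
q = torch.randn(8, 64, 576, dtype=torch.bfloat16)
print(pad_query_heads(q, 64, is_blackwell=True).shape)  # torch.Size([8, 128, 576])
```
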
vllm/v1/attention/backends/mla/indexer.py

Lines changed: 2 additions & 1 deletion
@@ -148,8 +148,9 @@ def kv_spans_from_batches(start_seq_loc: torch.Tensor,
 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
     max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+    max_num_seq = vllm_config.scheduler_config.max_num_seqs
     # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len + max_num_batched_tokens
+    return max_model_len * max_num_seq


 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
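
As a quick illustration of the revised estimate, the sketch below restates both formulas with plain integers in place of a `VllmConfig` object; the parameter names simply mirror the config fields used in the diff, and the example numbers are illustrative only. The new bound presumably allows every one of the `max_num_seqs` scheduled sequences to contribute up to `max_model_len` tokens to the flattened KV, whereas the old bound only added one batch worth of tokens.

```python
def old_estimate(max_model_len: int, max_num_batched_tokens: int) -> int:
    # previous return value in the diff
    return max_model_len + max_num_batched_tokens


def new_estimate(max_model_len: int, max_num_seqs: int) -> int:
    # NOTE(Chen) in the diff: "an estimated max size of flattened_kv"
    return max_model_len * max_num_seqs


# Illustrative numbers: 163,840-token context, 8,192-token batches,
# up to 16 concurrent sequences.
print(old_estimate(163_840, 8_192))  # 172032
print(new_estimate(163_840, 16))     # 2621440
```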
