2 files changed: +10 −6 lines, both under vllm/v1/attention/backends/mla.

First file:

@@ -21,6 +21,7 @@
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.platforms import current_platform

 if TYPE_CHECKING:
     from vllm.model_executor.models.deepseek_v2 import Indexer
@@ -388,13 +389,15 @@ def _forward_bf16_kv(
         kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
             -1, 1, kv_c_and_k_pe_cache.shape[-1])

-        # NOTE(Chen): kernel requires num_local_head to be a multiple of 64.
-        if self.num_heads % 64 != 0:
-            assert 64 % self.num_heads == 0
+        # NOTE(Chen): kernel requires num_local_head to be a multiple of
+        # 64 on Hopper and 128 on Blackwell.
+        padding = 128 if current_platform.is_device_capability(100) else 64
+        if self.num_heads % padding != 0:
+            assert padding % self.num_heads == 0
             logger.warning_once(
-                "padding num_heads to 64 due to sparse attn kernel requirement"
+                f"padding num_heads to {padding} due to sparse attn kernel requirement"
             )
-            q_padded = q.new_empty((q.shape[0], 64, q.shape[2]))
+            q_padded = q.new_empty((q.shape[0], padding, q.shape[2]))
             q_padded[:, :self.num_heads, :] = q
             q = q_padded
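For context, here is a minimal standalone sketch of the padding behavior this hunk introduces. The helper name pad_q_for_sparse_kernel and the is_blackwell flag are invented for illustration; the 64/128 multiples, the divisibility assert, and the new_empty-plus-copy pattern mirror the diff (in the diff itself, current_platform.is_device_capability(100), i.e. SM 10.0, selects the Blackwell path):

import torch


def pad_q_for_sparse_kernel(q: torch.Tensor, num_heads: int,
                            is_blackwell: bool) -> torch.Tensor:
    # The sparse attention kernel requires the local head count to be a
    # multiple of 64 on Hopper and 128 on Blackwell.
    padding = 128 if is_blackwell else 64
    if num_heads % padding == 0:
        return q  # already aligned; no copy needed
    # Mirrors the diff's assumption: num_heads must divide the padded size.
    assert padding % num_heads == 0
    # Allocate the padded tensor and copy the real heads to the front.
    # As with the diff's new_empty, the tail heads stay uninitialized.
    q_padded = q.new_empty((q.shape[0], padding, q.shape[2]))
    q_padded[:, :num_heads, :] = q
    return q_padded


# Example: 16 query heads padded up to 64 for the Hopper kernel.
q = torch.randn(8, 16, 576)
assert pad_q_for_sparse_kernel(q, 16, is_blackwell=False).shape == (8, 64, 576)

The padded tail heads are never zeroed, so presumably the kernel's outputs for heads beyond the real num_heads are discarded downstream.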
Second file:

@@ -148,8 +148,9 @@ def kv_spans_from_batches(start_seq_loc: torch.Tensor,
 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
     max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+    max_num_seq = vllm_config.scheduler_config.max_num_seqs
     # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len + max_num_batched_tokens
+    return max_model_len * max_num_seq


 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
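A quick sanity check on the new estimate, with invented numbers (no real config implied): the flattened KV buffer must cover the worst case of max_num_seqs sequences each carrying a full max_model_len KV history, which the old sum can underestimate:

# Hypothetical config values, for illustration only.
max_model_len = 4096
max_num_batched_tokens = 8192
max_num_seqs = 16

old_estimate = max_model_len + max_num_batched_tokens  # 12288 slots
new_estimate = max_model_len * max_num_seqs            # 65536 slots

# Worst case: max_num_seqs sequences, each with up to max_model_len
# cached tokens, flattened into a single buffer at once.
assert new_estimate == 65536

The NOTE in the diff still flags this as an estimate to be double-checked, so the new bound is conservative rather than exact.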