diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 2ec771a64557..898f71f31b51 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -1076,7 +1076,14 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
                 q,
                 k,
                 maybe_padded_v,
-                **kwargs,
+                None,  # output
+                kwargs["cu_seqlens_q"],
+                kwargs["cu_seqlens_k"],
+                kwargs["max_seqlen_q"],
+                kwargs["max_seqlen_k"],
+                kwargs["causal"],
+                softmax_scale,
+                None,  # bias
             )
         if is_vllm_fa:
             attn_out = self.flash_attn_varlen_func(
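
The hunk above swaps a `**kwargs` expansion for an explicit positional argument list, because the backend being called at this site takes its varlen-attention inputs in a fixed positional order, including an output-tensor slot and a bias slot. The sketch below shows that kwargs-to-positional adaptation in isolation; the argument order is inferred from the diff itself, and `positional_varlen_attn` / `call_positional_varlen_attn` are hypothetical names, not the real backend entry points.

# A minimal sketch, assuming a backend whose positional order matches the
# arguments added in the diff. Names here are illustrative stand-ins.
def call_positional_varlen_attn(positional_varlen_attn, q, k, maybe_padded_v,
                                softmax_scale, **kwargs):
    """Unpack keyword arguments into the positional order the backend expects."""
    return positional_varlen_attn(
        q,
        k,
        maybe_padded_v,
        None,                    # output tensor (left to the backend to allocate)
        kwargs["cu_seqlens_q"],  # cumulative query sequence lengths
        kwargs["cu_seqlens_k"],  # cumulative key sequence lengths
        kwargs["max_seqlen_q"],
        kwargs["max_seqlen_k"],
        kwargs["causal"],
        softmax_scale,
        None,                    # attention bias
    )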