diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
index e0e95d06290d..6d30296be4e2 100644
--- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -99,6 +99,7 @@ struct MlaSm100 {
 template <typename T>
 typename T::Fmha::Arguments args_from_options(
     at::Tensor const& out,
+    at::Tensor const& lse,
     at::Tensor const& q_nope,
     at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
@@ -162,7 +163,10 @@ typename T::Fmha::Arguments args_from_options(
        stride_PT,
        page_count_total,
        page_size},
-      {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE},
+      {static_cast(out.data_ptr()),
+       stride_O,
+       static_cast(lse.defined() ? lse.data_ptr() : nullptr),
+       stride_LSE},
       hw_info,
       // TODO(trevor-m): Change split_kv back to -1 when
       // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
@@ -181,6 +185,7 @@ typename T::Fmha::Arguments args_from_options(
 template <typename T>
 void runMla(
     at::Tensor const& out,
+    at::Tensor const& lse,
     at::Tensor const& q_nope,
     at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
@@ -192,7 +197,7 @@ void runMla(
     cudaStream_t stream) {
   using MlaSm100Type = MlaSm100;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
+  auto arguments = args_from_options(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
 
   CUTLASS_CHECK(fmha.can_implement(arguments));
 
@@ -214,6 +219,7 @@ void runMla(
 
 void sm100_cutlass_mla_decode(
     torch::Tensor const& out,
+    torch::Tensor const& lse,
     torch::Tensor const& q_nope,
     torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
@@ -234,13 +240,13 @@ void sm100_cutlass_mla_decode(
   DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
     if (in_dtype == at::ScalarType::Half) {
       runMla>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
+          out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::BFloat16) {
       runMla>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
+          out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
       runMla>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
+          out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else {
       TORCH_CHECK(false, "Unsupported input data type of MLA");
     }
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 85b6abef00b0..10b562a4c7ef 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -520,7 +520,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // SM100 CUTLASS MLA decode
   ops.def(
-      "sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
+      "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope, "
+      "Tensor q_pe,"
      " Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
      " Tensor page_table, Tensor workspace, float "
      "scale,"
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 92de39418054..1684e5b156c1 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1843,13 +1843,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
     return out
 
 
-def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
-                             q_pe: torch.Tensor,
+def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
+                             q_nope: torch.Tensor, q_pe: torch.Tensor,
                              kv_c_and_k_pe_cache: torch.Tensor,
                              seq_lens: torch.Tensor, page_table: torch.Tensor,
                              workspace: torch.Tensor, scale: float,
                              num_kv_splits: int) -> torch.Tensor:
-    torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe,
+    torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe,
                                           kv_c_and_k_pe_cache, seq_lens,
                                           page_table, workspace, scale,
                                           num_kv_splits)
diff --git a/vllm/attention/ops/merge_attn_states.py b/vllm/attention/ops/merge_attn_states.py
index 5cb1a47394cf..44f7ecdef4bf 100644
--- a/vllm/attention/ops/merge_attn_states.py
+++ b/vllm/attention/ops/merge_attn_states.py
@@ -41,3 +41,73 @@ def supported_headdim(o: torch.Tensor) -> bool:
                                 merge_attn_states)
     return merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                              suffix_lse, output_lse)
+
+
+def merge_multi_attn_states(partials: torch.Tensor,
+                            lse: torch.Tensor) -> torch.Tensor:
+    """Merge attention partials across a parallel dimension using LSE.
+
+    Args:
+        partials: [tp, B, H_owned, D]
+        lse: [tp, B, H_owned]
+
+    Returns:
+        merged: [B, H_owned, D]
+    """
+    assert partials.dim() == 4 and lse.dim() == 3, (
+        f"partials shape {partials.shape}, lse shape {lse.shape}")
+    tp, batch_size, heads_owned, dim = partials.shape
+    # [tp, B, H_owned] -> [B, H_owned]
+    max_lse, _ = torch.max(lse, dim=0)
+    # Avoid -inf producing NaNs
+    max_lse = torch.where(torch.isfinite(max_lse), max_lse,
+                          torch.zeros_like(max_lse))
+
+    # Compute exp-corrected weights and normalize across tp
+    # [tp, B, H_owned]
+    weights = torch.exp(lse - max_lse.unsqueeze(0))
+    denom = torch.clamp(weights.sum(dim=0, keepdim=False), min=1e-20)
+    weights = weights / denom
+
+    # Apply weights to partials: broadcast weights to dim
+    # [tp, B, H_owned, D]
+    weighted = partials * weights.unsqueeze(-1)
+    merged = weighted.sum(dim=0)
+    return merged
+
+
+def reduce_lse_over_tp(lse: torch.Tensor) -> torch.Tensor:
+    """Reduce per-rank LSE across TP via stable log-sum-exp.
+
+    Args:
+        lse: [tp, B, H_owned]
+
+    Returns:
+        reduced_lse: [B, H_owned]
+    """
+    assert lse.dim() == 3
+    tp_max, _ = torch.max(lse, dim=0)
+    tp_max = torch.where(torch.isfinite(tp_max), tp_max,
+                         torch.zeros_like(tp_max))
+    weights = torch.exp(lse - tp_max.unsqueeze(0))
+    denom = torch.clamp(weights.sum(dim=0, keepdim=False), min=1e-20)
+    return torch.log(denom) + tp_max
+
+
+def merge_multi_attn_states_with_lse(
+        partials: torch.Tensor,
+        lse: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Fused helper that returns merged outputs and reduced LSE.
+
+    Args:
+        partials: [tp, B, H_owned, D]
+        lse: [tp, B, H_owned]
+
+    Returns:
+        (merged, reduced_lse):
+            merged: [B, H_owned, D]
+            reduced_lse: [B, H_owned]
+    """
+    merged = merge_multi_attn_states(partials, lse)
+    reduced = reduce_lse_over_tp(lse)
+    return merged, reduced
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 7efab23f144a..4d5f217ed6d3 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -464,6 +464,13 @@ class ModelConfig:
     - "transformers" will use the Transformers model implementation."""
     override_attention_dtype: Optional[str] = None
     """Override dtype for attention"""
+    enable_mla_sharded_kv: bool = False
+    """Enable MLA sharded KV mode for tensor parallelism.
+
+    When enabled with tensor parallelism (>1), MLA decode will gather query
+    tensors across TP ranks to form full queries per rank. Without this flag,
+    MLA with TP>1 is disallowed to avoid silent fallbacks.
+    """
 
     def compute_hash(self) -> str:
         """
@@ -490,6 +497,7 @@ def compute_hash(self) -> str:
         factors.append(self.generation_config)
         factors.append(self.model_impl)
         factors.append(self.override_generation_config)
+        factors.append(self.enable_mla_sharded_kv)
         factors.append(self.rope_scaling)
         factors.append(self.rope_theta)
         # hf_config can control how the model looks!
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 0b3993ca0275..775cede1899f 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -400,6 +400,23 @@ def all_gatherv(self,
             raise ValueError("No device communicator found")
         return self.device_communicator.all_gatherv(input_, dim, sizes)
 
+    def all_to_all(self, input_: torch.Tensor, dim: int = 0) -> torch.Tensor:
+        """All-to-all over the device group, splitting along dim equally.
+
+        Note: This is a simple wrapper for torch.distributed.all_to_all_single
+        with equal splits across ranks.
+        """
+        world_size = self.world_size
+        if world_size == 1:
+            return input_
+        if dim < 0:
+            dim += input_.dim()
+        x = input_.movedim(dim, 0).contiguous()
+        assert x.shape[0] % world_size == 0
+        out = torch.empty_like(x)
+        torch.distributed.all_to_all_single(out, x, group=self.device_group)
+        return out.movedim(0, dim).contiguous()
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 4d4ce4c78e9f..0e5ac4b88c2a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -547,6 +547,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **model_kwargs["model_impl"])
         model_group.add_argument("--override-attention-dtype",
                                  **model_kwargs["override_attention_dtype"])
+        model_group.add_argument("--enable-mla-sharded-kv",
+                                 **model_kwargs["enable_mla_sharded_kv"])
 
         # Model loading arguments
         load_kwargs = get_kwargs(LoadConfig)
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index badff67656c2..31811664267e 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -200,9 +200,12 @@
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import get_mla_dims
-from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.ops.merge_attn_states import (merge_attn_states,
+                                                  merge_multi_attn_states,
+                                                  reduce_lse_over_tp)
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -368,6 +371,9 @@ class MLACommonMetadata(Generic[D]):
                                        FlashInferPrefillMetadata,
                                        CudnnPrefillMetadata]] = None
 
+    # Whether KV sharding is enabled for this batch/rank
+    sharded_kv: bool = False
+
     def __post_init__(self):
         if self.head_dim is not None:
             MLACommonBackend.validate_head_size(self.head_dim)
@@ -414,6 +420,9 @@ def __init__(self,
         self.device = device
         scheduler_config = vllm_config.scheduler_config
         self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.enable_mla_sharded_kv = vllm_config.model_config\
+            .enable_mla_sharded_kv
         cache_config = vllm_config.cache_config
         parallel_config = vllm_config.parallel_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
@@ -422,6 +431,13 @@ def __init__(self,
         self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
 
+        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config)
+        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_q_heads_decode = self.num_q_heads * self.tp_size \
+            if self.enable_mla_sharded_kv else self.num_q_heads
+
         # Dont try to access the runner on AMD
         if self.aot_schedule:
             self.page_size = self.kv_cache_spec.block_size
@@ -593,10 +609,13 @@ def build(self,
         device = self.device
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        # Optional sharding of KV metadata for MLA
+        enable_sharded_kv = self.enable_mla_sharded_kv
 
         query_start_loc = common_attn_metadata.query_start_loc
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
 
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
 
@@ -613,15 +632,54 @@ def build(self,
         if num_prefills > 0:
             reqs_start = num_decodes  # prefill_start
 
+            # Base context lens per sequence (unsharded)
             context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
             max_context_len_cpu = context_lens_cpu.max().item()
+            dummy_context = False
+
+            if enable_sharded_kv and max_context_len_cpu > 0:
+                B = self.kv_cache_spec.block_size
+                device = block_table_tensor.device
+                bt_pre = block_table_tensor[reqs_start:, ...]
+                # context lengths on device
+                context_lens = context_lens_cpu.to(device, non_blocking=True)
+                # Per-request used blocks and mask across existing table width
+                blocks_per_req = (context_lens +
+                                  (B - 1)) // B  # [num_prefills]
+                max_blocks_per_req = (max_context_len_cpu + B - 1) // B
+                arange_cols = torch.arange(max_blocks_per_req, device=device)
+                used_blocks_mask = arange_cols.unsqueeze(
+                    0) < blocks_per_req.view(-1, 1)
+
+                bt_pre = block_table_tensor[reqs_start:, :max_blocks_per_req]
+                bt_pre.masked_fill_(~used_blocks_mask, -1)
+                # Localize block table to this rank
+                owned = (bt_pre % self.tp_size == self.tp_rank) & (bt_pre >= 0)
+                bt_pre[:] = torch.where(owned, bt_pre // self.tp_size, -1)
+
+                # Per-row last block ownership (handle zero-block requests
+                # safely)
+                last_idx = torch.clamp(blocks_per_req - 1, min=0)
+                last_owned = owned.gather(
+                    1, last_idx.long().unsqueeze(1)).squeeze(1).cpu()
+                partial_last = (context_lens_cpu % B) != 0
+
+                # Compute effective per-rank context lengths
+                last_page_len = torch.where(
+                    (context_lens_cpu > 0) & partial_last & last_owned,
+                    context_lens_cpu % B + 1, B)
+                owned_counts = owned.sum(dim=1).cpu()
+                context_lens_cpu = (owned_counts - 1) * B + last_page_len
+                # Recompute:
+                max_context_len_cpu = context_lens_cpu.max().item()
+                dummy_context = (max_context_len_cpu == 0)
+
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = query_start_loc[
                 reqs_start:] - query_start_loc[reqs_start]
 
             chunked_context_metadata = None
             if self.chunked_prefill_enabled and num_prefills > 0 \
-                and max_context_len_cpu > 0:
+                and (max_context_len_cpu > 0 or dummy_context):
                 # NOTE: it is recommend you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
                 # understand the following code
@@ -641,7 +699,8 @@ def build(self,
                         self.page_size)
                 assert max_context_chunk > 0
 
-                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
+                num_chunks = cdiv(max_context_len_cpu, max_context_chunk) \
+                    if not dummy_context else 1
 
                 # if `max_context_chunk = 256`, `num_chunks = 3`, and
                 # `num_prefills_with_context = 4`, create a tensor that looks
@@ -702,9 +761,38 @@ def build(self,
 
         decode_metadata = None
         if num_decodes > 0:
+            decode_block_table = block_table_tensor[:num_decodes, ...]
+            decode_seq_lens_cpu = seq_lens_cpu[:num_decodes]
+            decode_seq_lens = seq_lens[:num_decodes]
+            if enable_sharded_kv:
+                B = self.kv_cache_spec.block_size
+                max_blocks_per_req = (decode_seq_lens_cpu.max().item() + B -
+                                      1) // B
+                decode_block_table = decode_block_table[:, :max_blocks_per_req]
+                blocks_per_req = (decode_seq_lens + B - 1) // B
+                used_blocks_mask = torch.arange(
+                    max_blocks_per_req, device=decode_block_table.device
+                ) < blocks_per_req.unsqueeze(1)
+                decode_block_table[~used_blocks_mask] = -1
+                # Localize block table and compute per-rank effective lengths
+                owned = (decode_block_table % self.tp_size == self.tp_rank) & \
+                    (decode_block_table != -1)
+                # Convert to local physical indices
+                decode_block_table = torch.where(
+                    owned, decode_block_table // self.tp_size, -1)
+
+                owned_blocks_per_req = owned.sum(dim=1)
+                owns_last_block = owned.gather(
+                    1, (blocks_per_req - 1).clamp(min=0).long().unsqueeze(1)
+                ).squeeze(1)
+                partial_last = (decode_seq_lens % B) != 0
+
+                # Adjust seq_lens to only account for owned blocks
+                decode_seq_lens = owned_blocks_per_req * B \
+                    - torch.where(owns_last_block & partial_last,
+                                  B - decode_seq_lens % B,
+                                  0)
+
             decode_metadata = self._build_decode(
-                block_table_tensor=block_table_tensor[:num_decodes, ...],
-                seq_lens=seq_lens[:num_decodes],
+                block_table_tensor=decode_block_table,
+                seq_lens=decode_seq_lens,
             )
 
         attn_metadata = self.metadata_cls(
@@ -720,6 +808,7 @@ def build(self,
             num_prefills=num_prefills,
             prefill=prefill_metadata,
             decode=decode_metadata,
+            sharded_kv=enable_sharded_kv,
         )
 
         if self._use_fi_prefill and num_prefills > 0:
@@ -763,7 +852,11 @@ def __init__(
         if kv_sharing_target_layer_name is not None:
             raise NotImplementedError("KV sharing is not supported for MLA")
 
-        self.num_heads = num_heads
+        self.tp_size = get_tp_group().world_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_group = get_tp_group()
+        self.num_local_heads = num_heads
+        self.num_global_heads = num_heads * self.tp_size
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
@@ -1102,14 +1195,40 @@ def _forward_prefill(
 
         if has_context:
             suffix_output, suffix_lse = output
-            context_output, context_lse = self._compute_prefill_context( \
-                q, kv_c_and_k_pe_cache, attn_metadata)
+            use_sharded = attn_metadata.sharded_kv
+            if use_sharded:
+                # All-gather Q across TP so each rank has all heads
+                from vllm.distributed import get_tp_group
+                tp_group = get_tp_group()
+                q_all = tp_group.all_gather(q, dim=1)
+                # Local compute using only local KV
+                # (metadata/block_table are sharded)
+                context_output_local, context_lse_local = \
+                    self._compute_prefill_context(q_all, kv_c_and_k_pe_cache,
+                                                  attn_metadata)
+                # Reshape to [tp, B, heads_owned, D]
+                B = context_output_local.shape[0]
+                heads_owned = self.num_local_heads
+                D = context_output_local.shape[2]
+                parts = context_output_local.view(B, self.num_global_heads, D) \
+                    .view(B, -1, heads_owned, D).movedim(1, 0)
+                lse_parts = context_lse_local.view(B, self.num_global_heads) \
+                    .view(B, -1, heads_owned).movedim(1, 0)
+                parts_exchanged = tp_group.all_to_all(parts, dim=0)
+                lse_exchanged = tp_group.all_to_all(lse_parts, dim=0)
+                context_output_owned = merge_multi_attn_states(
+                    parts_exchanged, lse_exchanged)
+                context_lse_owned = reduce_lse_over_tp(lse_exchanged)
+            else:
+                context_output_owned, context_lse_owned = \
+                    self._compute_prefill_context(q, kv_c_and_k_pe_cache,
+                                                  attn_metadata)
 
             output = torch.empty_like(suffix_output)
             merge_attn_states(
                 output=output,
-                prefix_output=context_output,
-                prefix_lse=context_lse,
+                prefix_output=context_output_owned,
+                prefix_lse=context_lse_owned,
                 suffix_output=suffix_output,
                 suffix_lse=suffix_lse,
             )
@@ -1120,14 +1239,60 @@
 
         return output.flatten(start_dim=-2)
 
-    @abstractmethod
-    def _forward_decode(
+    def _forward_decode_common(
         self,
         ql_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: M,
     ) -> torch.Tensor:
+        # Multi-rank Strategy B for decode; delegates local decode to
+        # subclass via _forward_decode_local which should return (o, lse)
+        tp = self.tp_size
+        use_sharded = bool(attn_metadata.sharded_kv and tp > 1)
+        # Form full-head Q for decode
+        if not use_sharded:
+            o, _ = self._forward_decode_local(ql_nope,
+                                              q_pe,
+                                              kv_c_and_k_pe_cache,
+                                              attn_metadata,
+                                              return_lse=False)
+            return self._v_up_proj(o)
+
+        # All-gather Q across TP so each rank has all heads for local decode
+        ql_nope_all = self.tp_group.all_gather(ql_nope, dim=1)
+        q_pe_all = self.tp_group.all_gather(q_pe, dim=1)
+
+        # Each rank computes local partial for all heads, then exchange+merge
+        local_out, local_lse = self._forward_decode_local(ql_nope_all,
+                                                          q_pe_all,
+                                                          kv_c_and_k_pe_cache,
+                                                          attn_metadata,
+                                                          return_lse=True)
+        if local_lse is None:
+            raise ValueError("LSE is required for sharded-mla decode")
+        B = local_out.shape[0]
+        D = local_out.shape[2]
+        parts = local_out.view(B, self.num_global_heads, D)\
+            .view(B, tp, self.num_local_heads, D)
+        lse_parts = local_lse.view(B, self.num_global_heads)\
+            .view(B, tp, self.num_local_heads)
+        parts_exchanged = self.tp_group.all_to_all(parts, dim=1)
+        lse_exchanged = self.tp_group.all_to_all(lse_parts, dim=1)
+
+        owned = merge_multi_attn_states(parts_exchanged.movedim(1, 0),
+                                        lse_exchanged.movedim(1, 0))
+        return self._v_up_proj(owned)
+
+    @abstractmethod
+    def _forward_decode_local(
+        self,
+        ql_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: M,
+        return_lse: bool,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         raise NotImplementedError
 
     def forward(
@@ -1179,11 +1344,14 @@ def forward(
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
+            # slot_mapping already pre-sharded in builder when enabled.
+            slot_map = attn_metadata.slot_mapping.flatten()
+
             ops.concat_and_cache_mla(
                 k_c_normed,
                 k_pe.squeeze(1),
                 kv_cache,
-                attn_metadata.slot_mapping.flatten(),
+                slot_map,
                 kv_cache_dtype=self.kv_cache_dtype,
                 scale=layer._k_scale,
             )
@@ -1204,7 +1372,7 @@ def forward(
 
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)
-            output[:num_decode_tokens] = self._forward_decode(
+            output[:num_decode_tokens] = self._forward_decode_common(
                 decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
 
         return output_padded
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index b23a8f0a5e87..658cfa937215 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -130,7 +130,7 @@ def _sm100_cutlass_mla_decode(
         workspace: torch.Tensor,
         sm_scale: float,
         num_kv_splits: int,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert (q_nope.ndim == 3
                 ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
         assert (
@@ -184,9 +184,13 @@ def _sm100_cutlass_mla_decode(
         ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
 
         out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
+        lse = torch.empty((B_q, MAX_HEADS),
+                          dtype=torch.float32,
+                          device=q_nope.device)
 
         ops.sm100_cutlass_mla_decode(
             out,
+            lse,
             q_nope,
             q_pe,
             kv_c_and_k_pe_cache,
@@ -196,7 +200,7 @@ def _sm100_cutlass_mla_decode(
             sm_scale,
             num_kv_splits,
         )
-        return out[:, :H].contiguous()
+        return out[:, :H].contiguous(), lse
 
     def _sm100_forward_decode(
         self,
@@ -206,7 +210,6 @@ def _sm100_forward_decode(
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        assert attn_metadata.decode is not None
 
         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError("FP8 Cutlass MLA not yet supported")
@@ -220,13 +223,13 @@ def _sm100_forward_decode(
             q_nope = q_nope.clone()
             q_pe = q_pe.clone()
 
-        o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
-                                           attn_metadata.decode.seq_lens,
-                                           attn_metadata.decode.block_table,
-                                           self._workspace.get_buf(),
-                                           self.scale, self._num_kv_splits)
+        assert attn_metadata.decode is not None
+        o, lse = self._sm100_cutlass_mla_decode(
+            q_nope, q_pe, kv_c_and_k_pe_cache,
+            attn_metadata.decode.seq_lens, attn_metadata.decode.block_table,
+            self._workspace.get_buf(), self.scale, self._num_kv_splits)
 
-        return self._v_up_proj(o)
+        return o, lse
 
     # TODO: Currently we leave it here only for backup in case something is
     # wrong with the new SM100 CUTLASS MLA kernel
@@ -260,13 +263,14 @@ def _old_forward_decode(
 
         return self._v_up_proj(o)
 
-    def _forward_decode(
+    def _forward_decode_local(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
-    ) -> torch.Tensor:
+        return_lse: bool,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         if self._use_old_cutlass_mla:
             # TODO: Remove the old cutlass MLA kernel after more extensive
             # testing
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 2b0f52cf80bf..4d4425c9ffa2 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -64,8 +64,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                          FlashMLAMetadata)
 
         self.compilation_config = vllm_config.compilation_config
-        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
-            vllm_config.parallel_config)
 
         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None
@@ -91,7 +89,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
             seq_lens,
-            self.num_q_heads,
+            self.num_q_heads_decode,
             1,  # MQA for the decode path
         )
@@ -167,20 +165,23 @@ def __init__(
             raise NotImplementedError(
                 "FlashMLA V1 with FP8 KV cache not yet supported")
 
-    def _forward_decode(
+    def _forward_decode_local(
         self,
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: FlashMLAMetadata,
-    ) -> torch.Tensor:
+        return_lse: bool,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        q = torch.cat([q_nope, q_pe], dim=-1)\
-            .unsqueeze(1)  # Add seqlen dim of 1 (decode)
+        # Form full-head query for decode
+        q = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1)
 
-        o, _ = flash_mla_with_kvcache(
+        # Compute local partial
+        # (all heads, local KV only due to sharded metadata)
+        o, lse = flash_mla_with_kvcache(
             q=q,
             k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             block_table=attn_metadata.decode.block_table,
@@ -192,5 +193,6 @@ def _forward_decode(
             softmax_scale=self.scale,
             causal=True,
         )
-
-        return self._v_up_proj(o)
+        o = o.squeeze(1)
+        lse = lse.squeeze(-1)
+        return o, (lse if return_lse else None)
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 626aa35a770c..2976ebb2980c 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -817,14 +817,23 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         for layer_name in kv_cache_spec
     ]
 
+    # If MLA sharded KV is enabled with tensor-parallelism, reserve the global
+    # id space as if blocks were replicated across TP ranks. We still only
+    # allocate local blocks physically, but report logical capacity as
+    # num_blocks * tp to the rest of the system.
+    tp = vllm_config.parallel_config.tensor_parallel_size
+    logical_num_blocks = num_blocks * tp if \
+        getattr(vllm_config.model_config, "enable_mla_sharded_kv", False) \
+        and tp > 1 else num_blocks
+
     kv_cache_config = KVCacheConfig(
-        num_blocks=num_blocks,
+        num_blocks=logical_num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                    grouped_layer_names),
     )
 
-    num_tokens = num_blocks * vllm_config.cache_config.block_size
+    num_tokens = logical_num_blocks * vllm_config.cache_config.block_size
     num_tokens_str = f"{num_tokens:,}"
     logger.info("GPU KV cache size: %s tokens", num_tokens_str)
     max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
@@ -978,8 +987,13 @@ def _get_kv_cache_config_uniform_page_size(
         kv_cache_tensors.append(
             KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
 
+    tp = vllm_config.parallel_config.tensor_parallel_size
+    logical_num_blocks = num_blocks * tp if \
+        getattr(vllm_config.model_config, "enable_mla_sharded_kv", False) \
+        and tp > 1 else num_blocks
+
     kv_cache_config = KVCacheConfig(
-        num_blocks=num_blocks,
+        num_blocks=logical_num_blocks,
         kv_cache_tensors=kv_cache_tensors,
         kv_cache_groups=kv_cache_groups,
     )
@@ -988,7 +1002,7 @@ def _get_kv_cache_config_uniform_page_size(
         [group.kv_cache_spec.block_size for group in kv_cache_groups])
 
     # Print the KV cache size and maximum concurrency.
-    num_tokens = num_blocks // len(grouped_layers) * min_block_size
+    num_tokens = logical_num_blocks // len(grouped_layers) * min_block_size
     num_tokens_str = f"{num_tokens:,}"
     logger.info("GPU KV cache size: %s tokens", num_tokens_str)
     max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
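
The two pieces of arithmetic this diff relies on are easy to check in isolation. The standalone sketch below is not part of the patch: it verifies the log-sum-exp merge implemented by merge_multi_attn_states / reduce_lse_over_tp against full softmax attention, and shows the block % tp / block // tp arithmetic the metadata builder uses to localize a block table. All shapes, the strided KV split, and the variable names are illustrative assumptions, not values taken from vLLM.

import torch

torch.manual_seed(0)
tp, B, H, D, S = 2, 3, 4, 8, 16   # ranks, batch, heads, head dim, kv length
scale = D ** -0.5

q = torch.randn(B, H, D)
k = torch.randn(B, S, D)
v = torch.randn(B, S, D)

# Reference: full attention over the entire KV sequence.
scores = torch.einsum("bhd,bsd->bhs", q, k) * scale
ref = torch.einsum("bhs,bsd->bhd", scores.softmax(dim=-1), v)

# Each "rank" attends to a disjoint KV shard and reports its partial output
# plus the log-sum-exp (LSE) of its local scores.
partials, lses = [], []
for r in range(tp):
    ks, vs = k[:, r::tp], v[:, r::tp]
    s = torch.einsum("bhd,bsd->bhs", q, ks) * scale
    lses.append(torch.logsumexp(s, dim=-1))                        # [B, H]
    partials.append(torch.einsum("bhs,bsd->bhd", s.softmax(dim=-1), vs))

partials = torch.stack(partials)   # [tp, B, H, D]
lse = torch.stack(lses)            # [tp, B, H]

# Merge exactly as the helpers do: exp-corrected weights, normalized over tp.
max_lse = lse.max(dim=0).values
w = torch.exp(lse - max_lse)
merged = (partials * (w / w.sum(0)).unsqueeze(-1)).sum(0)
reduced_lse = torch.log(w.sum(0)) + max_lse

assert torch.allclose(merged, ref, atol=1e-5)
assert torch.allclose(reduced_lse, torch.logsumexp(scores, dim=-1), atol=1e-5)
print("LSE merge reproduces full attention")

# Block-table localization used by the sharded-KV builder: a global block id
# g lives on rank (g % tp) at local slot (g // tp); -1 marks unused slots.
bt = torch.tensor([[0, 1, 2, 3, -1]])   # one request, global block ids
rank = 1
owned = (bt % tp == rank) & (bt >= 0)
local_bt = torch.where(owned, bt // tp, torch.full_like(bt, -1))
print(local_bt)                         # tensor([[-1,  0, -1,  1, -1]])

The same identity is what lets the prefill path merge a sharded context with the suffix via merge_attn_states, and what the decode path uses after the all-to-all exchange of per-rank partials.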