From 83c39e554c4380b3ecaa64119625ddea6822f231 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 10:08:24 +0300 Subject: [PATCH 01/44] feat: Added MambaStateShapeCalculator Signed-off-by: asafg --- .../layers/mamba/mamba_mixer2.py | 6 +- .../layers/mamba/mamba_utils.py | 109 ++++++++++-------- vllm/model_executor/models/bamba.py | 5 +- vllm/model_executor/models/falcon_h1.py | 5 +- .../model_executor/models/granitemoehybrid.py | 5 +- vllm/model_executor/models/mamba2.py | 5 +- vllm/model_executor/models/nemotron_h.py | 5 +- vllm/model_executor/models/zamba2.py | 5 +- 8 files changed, 80 insertions(+), 65 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 5ac9a7f9ab3e..9a5c22a3a185 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( - extra_groups_for_head_shards, get_mamba_state_shape) + MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated @@ -278,7 +278,7 @@ def __init__( # - for TP we shard conv_dim by sharding on n_groups, # - but if n_groups cannot divide tp_size, we need to # extend some extra groups - self.n_groups = n_groups + extra_groups_for_head_shards( + self.n_groups = n_groups + MambaStateShapeCalculator.extra_groups_for_head_shards( n_groups, self.tp_size) self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size @@ -732,7 +732,7 @@ def forward_cuda( output[:num_actual_tokens], _ = self.out_proj(hidden_states) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=self.intermediate_size, tp_world_size=get_tensor_model_parallel_world_size(), n_groups=self.n_groups, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 99a582066c0d..818c503fff4e 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -3,53 +3,62 @@ from vllm.distributed import divide -def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for - replication in order to accompany the head shards.""" - - # in the case ngoups % tp_size == 0, this will be zero - if ngroups % tp_size == 0: - return 0 - - # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups - - -def get_mamba_state_shape( - intermediate_size: int, - tp_world_size: int, - n_groups: int, - num_heads: int, - head_dim: int, - state_size: int, - conv_kernel: int, - use_v1: bool = True, -) -> tuple[tuple[int, int], tuple[int, int, int]]: - """ Get the shape of mamba state.""" - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (n_groups + - extra_groups_for_head_shards(n_groups, tp_world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + 2 * n_groups * state_size) - # contiguous along 'dim' axis - conv_state_shape = ( - conv_kernel - 1, - divide(conv_dim, tp_world_size), - ) - - if not use_v1: - 
conv_state_shape = (conv_state_shape[1], conv_state_shape[0]) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) - temporal_state_shape = ( - divide(num_heads, tp_world_size), - head_dim, - state_size, - ) - - return conv_state_shape, temporal_state_shape +class MambaStateShapeCalculator: + @classmethod # type: ignore + def mamba1_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + state_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int]]: + conv_state_shape = (conv_kernel - 1, divide(intermediate_size, tp_world_size)) + + if not use_v1: + conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) + + temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) + return conv_state_shape, temporal_state_shape + + @classmethod + def mamba2_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + n_groups: int, + num_heads: int, + head_dim: int, + state_size: int, + conv_kernel: int, + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (n_groups + + cls.extra_groups_for_head_shards(n_groups, tp_world_size)) + # heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * state_size + + # contiguous along 'dim' axis + conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) + if not use_v1: + conv_state_shape = (divide(conv_dim, tp_world_size), conv_kernel - 1) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) + temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) + return conv_state_shape, temporal_state_shape + + @classmethod + def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups \ No newline at end of file diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 0f5494427634..4a2ae07581f3 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -25,7 +25,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -457,7 +458,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, diff 
--git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 6a58b1501fe6..85d64af5bd28 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -24,7 +24,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -543,7 +544,7 @@ def get_mamba_state_shape_from_config( if hf_config.mamba_d_ssm is None else hf_config.mamba_d_ssm) - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 59c1dce48ee7..e59502f12a1c 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -23,7 +23,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -547,7 +548,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_n_groups, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index adad181617e6..75e92b01762d 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -19,7 +19,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -220,7 +221,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.n_groups, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 6a999e2254e7..eb62d5a53c1a 100644 
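Each of the per-model `get_mamba_state_shape_from_config` hooks above now delegates to `MambaStateShapeCalculator.mamba2_state_shape`. A minimal standalone sketch of that arithmetic (with a simplified `divide` standing in for `vllm.distributed.divide`, and purely illustrative sizes) shows how the tensor-parallel group padding feeds into the state shapes:

```python
# Standalone sketch of the shape arithmetic behind MambaStateShapeCalculator.
# `divide` is a simplified stand-in for vllm.distributed.divide; the sizes in
# the example call are illustrative, not taken from a real checkpoint.

def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator


def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    # No replication needed when the groups already divide evenly.
    if ngroups % tp_size == 0:
        return 0
    # Otherwise pad the group count so every head shard keeps its groups local.
    return tp_size - ngroups


def mamba2_state_shape(tp, intermediate, n_groups, heads, head_dim, d_state, d_conv):
    n_groups = n_groups + extra_groups_for_head_shards(n_groups, tp)
    conv_dim = intermediate + 2 * n_groups * d_state
    conv_state = (d_conv - 1, divide(conv_dim, tp))          # V1 layout
    temporal_state = (divide(heads, tp), head_dim, d_state)  # heads sharded; head_dim/d_state not
    return conv_state, temporal_state


# n_groups=1 with tp=4 gets padded to 4 groups, one replica per shard:
print(mamba2_state_shape(tp=4, intermediate=4096, n_groups=1, heads=128,
                         head_dim=64, d_state=128, d_conv=4))
# -> ((3, 1280), (32, 64, 128))
```

Padding the groups rather than splitting them keeps every head shard self-contained, at the cost of replicating group state whenever `n_groups` does not divide `tp_size`.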
--- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -39,7 +39,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -482,7 +483,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.n_groups, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 7764fd9b9e08..4cb0becf302f 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -32,7 +32,8 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 -from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -869,7 +870,7 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config intermediate_size = hf_config.mamba_expand * hf_config.hidden_size - return get_mamba_state_shape( + return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, tp_world_size=parallel_config.tensor_parallel_size, n_groups=hf_config.mamba_ngroups, From 2df9e5235d8c25bef9a161afc625932439ffe567 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 10:08:49 +0300 Subject: [PATCH 02/44] feat: Added Mamba1AttentionMetadata Signed-off-by: asafg --- .../layers/mamba/mamba1_metadata.py | 92 +++++++++++++ .../layers/mamba/mamba_mixer.py | 85 +++++++++--- vllm/v1/attention/backends/mamba1_attn.py | 129 ++++++++++++++++++ 3 files changed, 287 insertions(+), 19 deletions(-) create mode 100644 vllm/model_executor/layers/mamba/mamba1_metadata.py create mode 100644 vllm/v1/attention/backends/mamba1_attn.py diff --git a/vllm/model_executor/layers/mamba/mamba1_metadata.py b/vllm/model_executor/layers/mamba/mamba1_metadata.py new file mode 100644 index 000000000000..25c7a7a2bee1 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba1_metadata.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import torch + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.flash_attn import FlashAttentionMetadata + + +@dataclass +class Mamba1Metadata: + """Metadata for Mamba1 (original Mamba) implementation. 
+ + This class contains metadata needed for the MambaMixer to operate in continuous + batching and prefill modes. The metadata is computed at top-level model forward + since it stays the same and is reused for all mamba layers in the same iteration. + """ + # Tensor indicating which sequences have initial states (context_lens > 0) + has_initial_states: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] + + # Tensor containing the starting location of each query in the sequence + query_start_loc: Optional[torch.Tensor] + + # Tensor containing indices for accessing the state cache + state_indices_tensor: Optional[torch.Tensor] + + # Number of prefill requests (request count) + num_prefills: int + + # Number of decode tokens (token count = request) + num_decode_tokens: int + + # Number of prefill tokens (token count) + num_prefill_tokens: int + + +def prepare_mamba1_metadata( + attn_metadata: Union[AttentionMetadata, Dict[str, AttentionMetadata]], + mamba1_metadata: Optional[Mamba1Metadata] = None, +) -> Mamba1Metadata: + """Prepare metadata for Mamba1 from attention metadata. + + Args: + attn_metadata: Attention metadata containing sequence information. + Can be either AttentionMetadata or a dict mapping layer prefix to AttentionMetadata. + mamba1_metadata: Optional existing metadata to update + + Returns: + Mamba1Metadata object with required fields populated + """ + # Handle dict case + if isinstance(attn_metadata, dict): + # Take the first value since all layers should have same metadata + attn_metadata = next(iter(attn_metadata.values())) + + # Get counts from attention metadata + num_prefills = attn_metadata.num_prefills + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + # Get initial states info + if envs.VLLM_USE_V1: + has_initial_states = attn_metadata.context_lens_tensor > 0 + + # Get query start locations and state indices + query_start_loc = getattr(attn_metadata, 'query_start_loc', None) + state_indices_tensor = getattr(attn_metadata, 'state_indices_tensor', None) + + if mamba1_metadata is not None: + # Update existing metadata + mamba1_metadata.has_initial_states = has_initial_states + mamba1_metadata.query_start_loc = query_start_loc + mamba1_metadata.state_indices_tensor = state_indices_tensor + mamba1_metadata.num_prefills = num_prefills + mamba1_metadata.num_decode_tokens = num_decode_tokens + mamba1_metadata.num_prefill_tokens = num_prefill_tokens + return mamba1_metadata + + # Create new metadata + return Mamba1Metadata( + has_initial_states=has_initial_states, + query_start_loc=query_start_loc, + state_indices_tensor=state_indices_tensor, + num_prefills=num_prefills, + num_decode_tokens=num_decode_tokens, + num_prefill_tokens=num_prefill_tokens, + context_lens_tensor=attn_metadata.context_lens_tensor, + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 60cf3e11885a..b6f1e86a7c90 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch from torch import nn from torch.nn.parameter import Parameter +from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import ( 
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context @@ -14,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.mamba.mamba1_metadata import Mamba1Metadata from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @@ -47,7 +53,8 @@ def __init__(self, rms_norm_has_weight: bool = True, rms_norm_eps: float = 1e-5, activation="silu", - is_lora_enabled: bool = False): + is_lora_enabled: bool = False, + prefix: str = ""): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size @@ -131,14 +138,56 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): has_weight=rms_norm_has_weight, ) if use_rms_norm else None + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine + # Initialize with empty tensors in the correct format + # conv_state should be in (batch, width-1, dim) format + # so when transposed it becomes (batch, dim, width-1) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + + self.prefix = prefix + def forward_native(self, hidden_states: torch.Tensor, conv_state: torch.Tensor, ssm_state: torch.Tensor): pass def forward_cuda(self, hidden_states: torch.Tensor, - mamba_cache_params: MambaCacheParams): + mamba_cache_params: Optional[MambaCacheParams] = None, + mamba1_metadata: Optional[Mamba1Metadata] = None): - attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + attn_metadata: AttentionMetadata | Mamba1AttentionMetadata = get_forward_context().attn_metadata + + if envs.VLLM_USE_V1: + if attn_metadata is not None: + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.prefix] + + mamba1_metadata = attn_metadata + assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) + query_start_loc = mamba1_metadata.query_start_loc + state_indices_tensor = mamba1_metadata.state_indices_tensor + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # conv_state_raw : (batch, width-1, dim) + # ssm_state_raw : (batch, d_state, dim) + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1].transpose(-1, -2).contiguous() + has_initial_state = mamba1_metadata.has_initial_states + context_lens_tensor = mamba1_metadata.context_lens_tensor + else: + # For V0, we'll use the cache params and prepare metadata + assert mamba_cache_params is not None + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor + query_start_loc = attn_metadata.query_start_loc + context_lens_tensor = attn_metadata.context_lens_tensor + # context_lens_tensor = attn_metadata.seq_lens_tensor + if context_lens_tensor is not None: + has_initial_state = context_lens_tensor > 0 # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -148,8 +197,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: + if query_start_loc is not None and context_lens_tensor is not None: # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| @@ -161,18 +209,18 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_weights, bias=self.conv1d.bias, activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) + conv_states=conv_state, + has_initial_state=has_initial_state, + cache_indices=state_indices_tensor, + query_start_loc=query_start_loc) else: hidden_states = causal_conv1d_update( hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, + conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) + conv_state_indices=state_indices_tensor) hidden_states = hidden_states.transpose(0, 1) # 3. State Space Model sequence transformation @@ -203,11 +251,10 @@ def forward_cuda(self, hidden_states: torch.Tensor, time_proj_bias = (self.dt_proj.bias.float() if hasattr( self.dt_proj, "bias") else None) - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: + if query_start_loc is not None and context_lens_tensor is not None: scan_outputs = selective_scan_fn( hidden_states, - mamba_cache_params.ssm_state, + ssm_state, discrete_time_step, self.A, B.transpose(-2, -1), @@ -216,13 +263,13 @@ def forward_cuda(self, hidden_states: torch.Tensor, gate, time_proj_bias, delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) + cache_indices=state_indices_tensor, + has_initial_state=has_initial_state, + query_start_loc=query_start_loc) else: scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) selective_state_update( - mamba_cache_params.ssm_state, + ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), self.A, @@ -232,7 +279,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor, + state_batch_indices=state_indices_tensor, out=scan_outputs) scan_outputs = scan_outputs.transpose(0, 1) diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py new file mode 100644 index 000000000000..66cf2b5b2d74 --- /dev/null +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionBackend +from vllm.config import VllmConfig +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from 
vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + + +class Mamba1AttentionBackend(AttentionBackend): + + @staticmethod + def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: + return Mamba1AttentionMetadataBuilder + + +@dataclass +class Mamba1AttentionMetadata: + """ + Attention metadata for Mamba1 models. + + Mamba1 is simpler than Mamba2: + - No chunking/grouping + - No multi-head structure + - Simpler state management + """ + num_prefills: int + num_prefill_tokens: int + num_decodes: int + num_decode_tokens: int + query_start_loc: torch.Tensor # (batch+1,) cumulative offsets + seq_lens: torch.Tensor # (batch,) total lengths (computed + new) + context_lens_tensor: torch.Tensor # (batch,) already-computed tokens + state_indices_tensor: torch.Tensor # (batch,) one cache slot per request + has_initial_states: torch.Tensor # (batch,) bool mask + cu_seqlen: int # max_query_len, for buffer sizing + nums_dict: Optional[dict] = None + batch_ptr: Optional[torch.Tensor] = None + + +class Mamba1AttentionMetadataBuilder(AttentionMetadataBuilder[Mamba1AttentionMetadata]): + + def __init__( + self, + kv_cache_spec: AttentionSpec, + vllm_config: VllmConfig, + device: torch.device, + ): + assert isinstance(kv_cache_spec, MambaSpec) + self.kv_cache_spec = kv_cache_spec + self.device = device + self.vllm_config = vllm_config + + def reorder_batch( + self, + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + ) -> bool: + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=1, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> Mamba1AttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc # already on GPU + query_start_loc = query_start_loc.to(torch.int32) + + # Total sequence lengths (computed + new), on GPU, int32 + seq_lens = ( + common_attn_metadata.seq_lens + .to(query_start_loc.device) + .to(torch.int32) + ) + + # How many tokens were already computed per request (prefill), + # on GPU, int32 + context_lens = ( + common_attn_metadata.num_computed_tokens_cpu + .to(query_start_loc.device) + .to(torch.int32) + ) + + # Split out decode vs prefill phases + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=1, + ) + + # Which requests have any prior state + has_initial_states = (context_lens > 0) + + # One cache-slot index per request (like Mamba2), cast to int32 + state_indices = ( + common_attn_metadata.block_table_tensor[:, 0] + .to(query_start_loc.device) + .to(torch.int32) + ) + + return Mamba1AttentionMetadata( + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + context_lens_tensor=context_lens, + has_initial_states=has_initial_states.to(query_start_loc.device), + state_indices_tensor=state_indices, + cu_seqlen=common_attn_metadata.max_query_len, + ) \ No newline at end of file From 8ddb42e1595e4a300bd7619303954cef756e2c83 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 10:23:29 +0300 Subject: [PATCH 03/44] feat: Added V1 code to mamba Signed-off-by: asafg --- vllm/model_executor/models/mamba.py 
| 82 +++++++++++++++++------------ 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 8162ac3f7597..ae21a38b8a27 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -8,20 +8,24 @@ from torch import nn from transformers import MambaConfig +from vllm import envs from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba1_metadata import ( + Mamba1Metadata, prepare_mamba1_metadata) from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree, SupportsPP, - SupportsV0Only) + IsAttentionFree, SupportsPP) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -41,7 +45,8 @@ def __init__(self, config: MambaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False) -> None: + is_lora_enabled: Optional[bool] = False, + prefix: str = "") -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" @@ -58,7 +63,8 @@ def __init__(self, rms_norm_has_weight=not self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, activation=config.hidden_act, - is_lora_enabled=self.is_lora_enabled) + is_lora_enabled=self.is_lora_enabled, + prefix=f"{prefix}.mixer") self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -67,6 +73,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, + mamba1_metadata: Mamba1Metadata, **kwargs, ): if residual is None: @@ -75,7 +82,7 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params) + hidden_states = self.mixer(hidden_states, mamba_cache_params, mamba1_metadata) return hidden_states, residual @@ -107,7 +114,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: MambaDecoderLayer(config, cache_config=cache_config, quant_config=quant_config, - is_lora_enabled=is_lora_enabled), + is_lora_enabled=is_lora_enabled, + prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -123,10 +131,17 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - mamba_cache_params: MambaCacheParams, + mamba_cache_params: Optional[MambaCacheParams] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: + mamba1_metadata = None + + if not envs.VLLM_USE_V1: + attn_metadata = get_forward_context().attn_metadata + 
mamba1_metadata = prepare_mamba1_metadata(attn_metadata) + + if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -140,12 +155,17 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + + layer_cache_params = None + if mamba_cache_params is not None: + layer_cache_params = mamba_cache_params.at_layer_idx(i - self.start_layer) + hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx( - i - self.start_layer)) + mamba1_metadata=mamba1_metadata, + mamba_cache_params=layer_cache_params) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -176,8 +196,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, - SupportsV0Only): +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -227,14 +246,23 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) - - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + mamba_state_shape = MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=self.vllm_config.parallel_config.tensor_parallel_size, + intermediate_size=self.config.mamba_expand * self.config.hidden_size, + state_size=self.config.mamba_d_state, + conv_kernel=self.config.mamba_d_conv, + use_v1=False) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *mamba_state_shape) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) hidden_states = self.backbone(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) @@ -247,19 +275,7 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - conv_state_shape = ( - self.config.intermediate_size // world_size, - self.config.conv_kernel - 1, - ) - temporal_state_shape = ( - self.config.intermediate_size // world_size, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: From 52699f4aef057dbdd851d25ab0ceca60c3cb50f4 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 10:30:10 +0300 Subject: [PATCH 04/44] fix: Removed unnecessary mamba1 metadata dataclass Signed-off-by: asafg --- .../layers/mamba/mamba1_metadata.py | 92 ------------------- .../layers/mamba/mamba_mixer.py | 4 +- 
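In the V0 fallback above, `MambaCacheManager` is now sized with `MambaStateShapeCalculator.mamba1_state_shape` instead of the removed `_get_mamba_cache_shape` helper. A condensed sketch of what that call returns (plain integer division standing in for `vllm.distributed.divide`; the sizes are illustrative):

```python
# Sketch of mamba1_state_shape: only the conv-state layout differs between the
# V0 and V1 code paths; the temporal (SSM) state shape is the same in both.

def mamba1_state_shape(tp_world_size, intermediate_size, state_size, conv_kernel,
                       use_v1=True):
    inner = intermediate_size // tp_world_size        # simplified divide()
    conv_state = ((conv_kernel - 1, inner) if use_v1
                  else (inner, conv_kernel - 1))      # V0 keeps the conv width trailing
    temporal_state = (inner, state_size)
    return conv_state, temporal_state


# Illustrative sizes: intermediate=1536 (expand 2 x hidden 768), d_state=16, d_conv=4.
print(mamba1_state_shape(1, 1536, 16, 4, use_v1=False))  # ((1536, 3), (1536, 16))
print(mamba1_state_shape(1, 1536, 16, 4, use_v1=True))   # ((3, 1536), (1536, 16))
```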
vllm/model_executor/models/mamba.py | 14 +-- 3 files changed, 2 insertions(+), 108 deletions(-) delete mode 100644 vllm/model_executor/layers/mamba/mamba1_metadata.py diff --git a/vllm/model_executor/layers/mamba/mamba1_metadata.py b/vllm/model_executor/layers/mamba/mamba1_metadata.py deleted file mode 100644 index 25c7a7a2bee1..000000000000 --- a/vllm/model_executor/layers/mamba/mamba1_metadata.py +++ /dev/null @@ -1,92 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass -from typing import Dict, Optional, Union - -import torch - -from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.backends.flash_attn import FlashAttentionMetadata - - -@dataclass -class Mamba1Metadata: - """Metadata for Mamba1 (original Mamba) implementation. - - This class contains metadata needed for the MambaMixer to operate in continuous - batching and prefill modes. The metadata is computed at top-level model forward - since it stays the same and is reused for all mamba layers in the same iteration. - """ - # Tensor indicating which sequences have initial states (context_lens > 0) - has_initial_states: Optional[torch.Tensor] - context_lens_tensor: Optional[torch.Tensor] - - # Tensor containing the starting location of each query in the sequence - query_start_loc: Optional[torch.Tensor] - - # Tensor containing indices for accessing the state cache - state_indices_tensor: Optional[torch.Tensor] - - # Number of prefill requests (request count) - num_prefills: int - - # Number of decode tokens (token count = request) - num_decode_tokens: int - - # Number of prefill tokens (token count) - num_prefill_tokens: int - - -def prepare_mamba1_metadata( - attn_metadata: Union[AttentionMetadata, Dict[str, AttentionMetadata]], - mamba1_metadata: Optional[Mamba1Metadata] = None, -) -> Mamba1Metadata: - """Prepare metadata for Mamba1 from attention metadata. - - Args: - attn_metadata: Attention metadata containing sequence information. - Can be either AttentionMetadata or a dict mapping layer prefix to AttentionMetadata. 
- mamba1_metadata: Optional existing metadata to update - - Returns: - Mamba1Metadata object with required fields populated - """ - # Handle dict case - if isinstance(attn_metadata, dict): - # Take the first value since all layers should have same metadata - attn_metadata = next(iter(attn_metadata.values())) - - # Get counts from attention metadata - num_prefills = attn_metadata.num_prefills - num_decode_tokens = attn_metadata.num_decode_tokens - num_prefill_tokens = attn_metadata.num_prefill_tokens - - # Get initial states info - if envs.VLLM_USE_V1: - has_initial_states = attn_metadata.context_lens_tensor > 0 - - # Get query start locations and state indices - query_start_loc = getattr(attn_metadata, 'query_start_loc', None) - state_indices_tensor = getattr(attn_metadata, 'state_indices_tensor', None) - - if mamba1_metadata is not None: - # Update existing metadata - mamba1_metadata.has_initial_states = has_initial_states - mamba1_metadata.query_start_loc = query_start_loc - mamba1_metadata.state_indices_tensor = state_indices_tensor - mamba1_metadata.num_prefills = num_prefills - mamba1_metadata.num_decode_tokens = num_decode_tokens - mamba1_metadata.num_prefill_tokens = num_prefill_tokens - return mamba1_metadata - - # Create new metadata - return Mamba1Metadata( - has_initial_states=has_initial_states, - query_start_loc=query_start_loc, - state_indices_tensor=state_indices_tensor, - num_prefills=num_prefills, - num_decode_tokens=num_decode_tokens, - num_prefill_tokens=num_prefill_tokens, - context_lens_tensor=attn_metadata.context_lens_tensor, - ) \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index b6f1e86a7c90..c1bd362c2804 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.mamba.mamba1_metadata import Mamba1Metadata from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -156,8 +155,7 @@ def forward_native(self, hidden_states: torch.Tensor, pass def forward_cuda(self, hidden_states: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None, - mamba1_metadata: Optional[Mamba1Metadata] = None): + mamba_cache_params: Optional[MambaCacheParams] = None): attn_metadata: AttentionMetadata | Mamba1AttentionMetadata = get_forward_context().attn_metadata diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ae21a38b8a27..910bee567ff2 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -11,11 +11,8 @@ from vllm import envs from vllm.config import CacheConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba1_metadata import ( - Mamba1Metadata, prepare_mamba1_metadata) from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateShapeCalculator) @@ -73,7 +70,6 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], mamba_cache_params: 
MambaCacheParams, - mamba1_metadata: Mamba1Metadata, **kwargs, ): if residual is None: @@ -82,7 +78,7 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params, mamba1_metadata) + hidden_states = self.mixer(hidden_states, mamba_cache_params) return hidden_states, residual @@ -135,13 +131,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - mamba1_metadata = None - - if not envs.VLLM_USE_V1: - attn_metadata = get_forward_context().attn_metadata - mamba1_metadata = prepare_mamba1_metadata(attn_metadata) - - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -164,7 +153,6 @@ def forward( positions=positions, hidden_states=hidden_states, residual=residual, - mamba1_metadata=mamba1_metadata, mamba_cache_params=layer_cache_params) if not get_pp_group().is_last_rank: return IntermediateTensors({ From 57f7316bcb0158f8d466e20ecec4013ac0ae8b90 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 11:54:41 +0300 Subject: [PATCH 05/44] fix: Updated configs in mamba Signed-off-by: asafg --- .../layers/mamba/mamba_mixer.py | 15 +++++++++- vllm/model_executor/models/mamba.py | 6 ++-- vllm/v1/attention/backends/mamba1_attn.py | 29 +++---------------- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index c1bd362c2804..848017a9610d 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -18,6 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -60,6 +62,8 @@ def __init__(self, self.use_rms_norm = use_rms_norm self.activation = activation self.is_lora_enabled = is_lora_enabled + self.conv_kernel_size = conv_kernel_size + self.intermediate_size = intermediate_size self.conv1d = ColumnParallelLinear( input_size=conv_kernel_size, @@ -174,7 +178,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1].transpose(-1, -2).contiguous() has_initial_state = mamba1_metadata.has_initial_states - context_lens_tensor = mamba1_metadata.context_lens_tensor + context_lens_tensor = mamba1_metadata.seq_lens else: # For V0, we'll use the cache params and prepare metadata assert mamba_cache_params is not None @@ -290,3 +294,12 @@ def forward_cuda(self, hidden_states: torch.Tensor, contextualized_states = self.out_proj( scan_outputs.transpose(-2, -1))[0] return contextualized_states + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.intermediate_size, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, + use_v1=envs.VLLM_USE_V1, + ) \ No newline at end of file diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 910bee567ff2..4cad803db0c6 100644 --- a/vllm/model_executor/models/mamba.py +++ 
b/vllm/model_executor/models/mamba.py @@ -242,9 +242,9 @@ def forward(self, self.vllm_config.parallel_config, LayerBlockType.mamba) mamba_state_shape = MambaStateShapeCalculator.mamba1_state_shape( tp_world_size=self.vllm_config.parallel_config.tensor_parallel_size, - intermediate_size=self.config.mamba_expand * self.config.hidden_size, - state_size=self.config.mamba_d_state, - conv_kernel=self.config.mamba_d_conv, + intermediate_size=self.config.intermediate_size, + state_size=self.config.state_size, + conv_kernel=self.config.conv_kernel, use_v1=False) self.mamba_cache = MambaCacheManager( self.vllm_config, self.lm_head.weight.dtype, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 66cf2b5b2d74..1a6706adefca 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -41,7 +41,6 @@ class Mamba1AttentionMetadata: num_decode_tokens: int query_start_loc: torch.Tensor # (batch+1,) cumulative offsets seq_lens: torch.Tensor # (batch,) total lengths (computed + new) - context_lens_tensor: torch.Tensor # (batch,) already-computed tokens state_indices_tensor: torch.Tensor # (batch,) one cache slot per request has_initial_states: torch.Tensor # (batch,) bool mask cu_seqlen: int # max_query_len, for buffer sizing @@ -79,41 +78,22 @@ def build( common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, ) -> Mamba1AttentionMetadata: - num_reqs = common_attn_metadata.num_reqs - query_start_loc = common_attn_metadata.query_start_loc # already on GPU - query_start_loc = query_start_loc.to(torch.int32) + query_start_loc = common_attn_metadata.query_start_loc - # Total sequence lengths (computed + new), on GPU, int32 seq_lens = ( common_attn_metadata.seq_lens .to(query_start_loc.device) .to(torch.int32) ) - # How many tokens were already computed per request (prefill), - # on GPU, int32 - context_lens = ( - common_attn_metadata.num_computed_tokens_cpu - .to(query_start_loc.device) - .to(torch.int32) - ) - - # Split out decode vs prefill phases num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills( common_attn_metadata, decode_threshold=1, ) - # Which requests have any prior state - has_initial_states = (context_lens > 0) - - # One cache-slot index per request (like Mamba2), cast to int32 - state_indices = ( - common_attn_metadata.block_table_tensor[:, 0] - .to(query_start_loc.device) - .to(torch.int32) - ) + has_initial_states = (seq_lens > 0) + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] return Mamba1AttentionMetadata( num_prefills=num_prefills, @@ -122,8 +102,7 @@ def build( num_decode_tokens=num_decode_tokens, query_start_loc=query_start_loc, seq_lens=seq_lens, - context_lens_tensor=context_lens, has_initial_states=has_initial_states.to(query_start_loc.device), - state_indices_tensor=state_indices, + state_indices_tensor=state_indices_tensor, cu_seqlen=common_attn_metadata.max_query_len, ) \ No newline at end of file From 6f4b1db10e2d61bf0a4b578819bb90ac8544e80a Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 12:34:07 +0300 Subject: [PATCH 06/44] refactor: Added v1 condition Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 848017a9610d..a742a89b687d 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py 
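The simplified builder above derives everything it needs from `CommonAttentionMetadata`. A toy, self-contained walk-through of the same bookkeeping (the batch is invented, and the decode/prefill split is computed by hand under the same decode-first, `decode_threshold=1` convention rather than by calling `split_decodes_and_prefills`):

```python
# Toy batch: two decode requests (one new token each) followed by one prefill
# request with five new tokens, matching the decode-first ordering enforced by
# reorder_batch_to_split_decodes_and_prefills().
import torch

query_start_loc = torch.tensor([0, 1, 2, 7], dtype=torch.int32)  # cumulative new tokens
seq_lens = torch.tensor([9, 12, 5], dtype=torch.int32)           # computed + new, per request
block_table_col0 = torch.tensor([3, 0, 7], dtype=torch.int32)    # one state slot per request

num_new_tokens = query_start_loc[1:] - query_start_loc[:-1]       # [1, 1, 5]
is_decode = num_new_tokens <= 1                                    # decode_threshold=1
num_decodes = int(is_decode.sum())                                 # 2
num_prefills = seq_lens.numel() - num_decodes                      # 1
num_decode_tokens = int(num_new_tokens[is_decode].sum())           # 2
num_prefill_tokens = int(num_new_tokens[~is_decode].sum())         # 5

has_initial_states = seq_lens > 0         # as in the patch-05 builder
state_indices_tensor = block_table_col0   # first block of each row holds the mamba state

print(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)  # 2 1 2 5
print(has_initial_states.tolist(), state_indices_tensor.tolist())
```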
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -173,21 +173,18 @@ def forward_cuda(self, hidden_states: torch.Tensor, query_start_loc = mamba1_metadata.query_start_loc state_indices_tensor = mamba1_metadata.state_indices_tensor self_kv_cache = self.kv_cache[forward_context.virtual_engine] - # conv_state_raw : (batch, width-1, dim) - # ssm_state_raw : (batch, d_state, dim) conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1].transpose(-1, -2).contiguous() + ssm_state = self_kv_cache[1].contiguous() has_initial_state = mamba1_metadata.has_initial_states context_lens_tensor = mamba1_metadata.seq_lens else: - # For V0, we'll use the cache params and prepare metadata assert mamba_cache_params is not None conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state state_indices_tensor = mamba_cache_params.state_indices_tensor query_start_loc = attn_metadata.query_start_loc context_lens_tensor = attn_metadata.context_lens_tensor - # context_lens_tensor = attn_metadata.seq_lens_tensor + if context_lens_tensor is not None: has_initial_state = context_lens_tensor > 0 @@ -199,6 +196,11 @@ def forward_cuda(self, hidden_states: torch.Tensor, conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states = hidden_states.contiguous() + return self.out_proj(hidden_states.transpose(-2, -1))[0] + if query_start_loc is not None and context_lens_tensor is not None: # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| From 5165e0a3d132ab4d1ced98207837865fd68d56ea Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 12:54:53 +0300 Subject: [PATCH 07/44] fix: Removed unnecesary ignore Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 818c503fff4e..8ca97dbc6af6 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -4,7 +4,7 @@ class MambaStateShapeCalculator: - @classmethod # type: ignore + @classmethod def mamba1_state_shape( cls, tp_world_size: int, From c4a3bcdd2551fa87f509930d8993d1608edc0bce Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 13:23:16 +0300 Subject: [PATCH 08/44] feat: Added mamba type to identify mamba version Signed-off-by: asafg --- vllm/v1/attention/backends/mamba1_attn.py | 10 +++++----- vllm/v1/kv_cache_interface.py | 4 ++++ vllm/v1/worker/gpu_model_runner.py | 9 ++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 1a6706adefca..af54d3776153 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -39,11 +39,11 @@ class Mamba1AttentionMetadata: num_prefill_tokens: int num_decodes: int num_decode_tokens: int - query_start_loc: torch.Tensor # (batch+1,) cumulative offsets - seq_lens: torch.Tensor # (batch,) total lengths (computed + new) - state_indices_tensor: torch.Tensor # (batch,) one cache slot per request - has_initial_states: torch.Tensor # (batch,) bool mask - cu_seqlen: int # max_query_len, for buffer sizing + query_start_loc: torch.Tensor + seq_lens: torch.Tensor + state_indices_tensor: torch.Tensor + has_initial_states: torch.Tensor + 
cu_seqlen: int nums_dict: Optional[dict] = None batch_ptr: Optional[torch.Tensor] = None diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4ff96f9786b8..e65352ddf6b9 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -178,6 +178,10 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # window [CDEF] of 6 tokens. return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes +class MambaType(str, Enum): + MAMBA1 = "mamba1" + MAMBA2 = "mamba2" + @dataclass(frozen=True) class MambaSpec(KVCacheSpec): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 041687ae28b2..c5ae6ef20852 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,6 +28,7 @@ prepare_communication_buffer_for_model) from vllm.forward_context import DPMetadata, set_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader @@ -45,6 +46,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up, supports_dynamo) +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, @@ -55,7 +57,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, + KVCacheSpec, MambaSpec, MambaType, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -3015,6 +3017,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. 
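A brief aside before the per-layer loop that follows: the `MambaType` enum added above subclasses both `str` and `Enum`, so its members compare equal to the plain strings that the existing `mamba_type` hooks return. A quick sketch of that behavior:

```python
from enum import Enum


class MambaType(str, Enum):
    MAMBA1 = "mamba1"
    MAMBA2 = "mamba2"


# str-subclassing enums compare equal to bare strings, so call sites that still
# pass or return "mamba1"/"mamba2" keep working unchanged.
assert MambaType.MAMBA1 == "mamba1"
assert MambaType("mamba2") is MambaType.MAMBA2   # lookup by value
print(MambaType.MAMBA1.value)                    # mamba1
```

That makes the enum a drop-in refinement over the existing `mamba_type` strings rather than a breaking change.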
for layer_name, mamba_module in mamba_layers.items(): + if isinstance(mamba_module, MambaMixer): + mamba_type = MambaType.MAMBA1 + else: + mamba_type = MambaType.MAMBA2 + kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), dtype=self.kv_cache_dtype, From 476ba5b74b7de2d67fc6898c7ded183de3b10171 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 13:30:39 +0300 Subject: [PATCH 09/44] fix: Added mamba_type property to mamba_base Signed-off-by: asafg --- vllm/model_executor/layers/mamba/abstract.py | 2 ++ vllm/model_executor/layers/mamba/mamba_mixer.py | 10 ++++++++-- vllm/model_executor/layers/mamba/mamba_mixer2.py | 1 + vllm/v1/worker/gpu_model_runner.py | 6 ------ 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index daebe46f6f77..0c90c062053a 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -5,6 +5,8 @@ import torch +from vllm.v1.kv_cache_interface import MambaType + class MambaBase(ABC): """ diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a742a89b687d..1fa21eac10a1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( @@ -27,11 +28,12 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata +from vllm.v1.kv_cache_interface import MambaType # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @CustomOp.register("mamba_mixer") -class MambaMixer(CustomOp): +class MambaMixer(MambaBase, CustomOp): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
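The abstract `MambaBase` above now imports `MambaType`, and the mixer gains a `mamba_type` property just below, so each layer advertises which backend it needs. A minimal sketch of the resulting dispatch pattern (`ToyMamba1Mixer` and `select_backend` are hypothetical stand-ins, not vLLM's actual classes or selector):

```python
# Toy illustration of the "layer reports its own mamba_type" pattern.
from abc import ABC, abstractmethod


class MambaBase(ABC):

    @abstractmethod
    def get_state_shape(self) -> tuple:
        ...

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        ...


class ToyMamba1Mixer(MambaBase):

    def get_state_shape(self) -> tuple:
        return ((3, 1536), (1536, 16))   # illustrative mamba1 shapes

    @property
    def mamba_type(self) -> str:
        return "mamba1"


_BACKENDS = {"mamba1": "Mamba1AttentionBackend", "mamba2": "Mamba2AttentionBackend"}


def select_backend(layer: MambaBase) -> str:
    # The model runner only needs the layer's self-reported type to pick a backend.
    return _BACKENDS[layer.mamba_type]


print(select_backend(ToyMamba1Mixer()))   # Mamba1AttentionBackend
```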
A, D are input independent @@ -304,4 +306,8 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: state_size=self.ssm_state_size, conv_kernel=self.conv_kernel_size, use_v1=envs.VLLM_USE_V1, - ) \ No newline at end of file + ) + + @property + def mamba_type(self) -> MambaType: + return MambaType.MAMBA1 \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 9a5c22a3a185..023ef77f08e0 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -37,6 +37,7 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata +from vllm.v1.kv_cache_interface import MambaType # Added by the IBM Team, 2024 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c5ae6ef20852..4d54b4772c48 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,7 +28,6 @@ prepare_communication_buffer_for_model) from vllm.forward_context import DPMetadata, set_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader @@ -3017,11 +3016,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): - if isinstance(mamba_module, MambaMixer): - mamba_type = MambaType.MAMBA1 - else: - mamba_type = MambaType.MAMBA2 - kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), dtype=self.kv_cache_dtype, From e747f2747801783a40e626251098356bb0a9b213 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 14:20:00 +0300 Subject: [PATCH 10/44] fix: Lint Signed-off-by: asafg --- .../layers/mamba/mamba_mixer.py | 10 +++++---- .../layers/mamba/mamba_mixer2.py | 1 + .../layers/mamba/mamba_utils.py | 22 ++++++++++++------- vllm/model_executor/models/mamba.py | 18 ++++++++------- vllm/v1/attention/backends/mamba1_attn.py | 12 +++++----- vllm/v1/kv_cache_interface.py | 1 + vllm/v1/worker/gpu_model_runner.py | 6 +++-- 7 files changed, 41 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 1fa21eac10a1..1413a674ed0c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -160,10 +160,12 @@ def forward_native(self, hidden_states: torch.Tensor, conv_state: torch.Tensor, ssm_state: torch.Tensor): pass - def forward_cuda(self, hidden_states: torch.Tensor, + def forward_cuda(self, + hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): - attn_metadata: AttentionMetadata | Mamba1AttentionMetadata = get_forward_context().attn_metadata + attn_metadata: AttentionMetadata | Mamba1AttentionMetadata = + get_forward_context().attn_metadata if envs.VLLM_USE_V1: if attn_metadata is not None: @@ -186,7 +188,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, state_indices_tensor = mamba_cache_params.state_indices_tensor query_start_loc = attn_metadata.query_start_loc 
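In the V1 branch of `forward_cuda` shown earlier in this series, the conv state is read straight out of the layer's kv-cache entry via a transpose rather than a copy. A tiny torch sketch of that layout handling, with made-up sizes:

```python
import torch

# The V1 cache allocates conv state as (blocks, conv_kernel-1, inner_dim) and
# ssm state as (blocks, inner_dim, d_state); the sizes below are made up.
num_blocks, conv_width_minus_1, inner_dim, d_state = 8, 3, 1536, 16
self_kv_cache = (
    torch.zeros(num_blocks, conv_width_minus_1, inner_dim),
    torch.zeros(num_blocks, inner_dim, d_state),
)

conv_state = self_kv_cache[0].transpose(-1, -2)   # (8, 1536, 3) view, no copy
ssm_state = self_kv_cache[1].contiguous()         # already (8, 1536, 16)

print(conv_state.shape, conv_state.is_contiguous())   # torch.Size([8, 1536, 3]) False
print(ssm_state.shape)                                 # torch.Size([8, 1536, 16])
```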
context_lens_tensor = attn_metadata.context_lens_tensor - + if context_lens_tensor is not None: has_initial_state = context_lens_tensor > 0 @@ -310,4 +312,4 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: @property def mamba_type(self) -> MambaType: - return MambaType.MAMBA1 \ No newline at end of file + return MambaType.MAMBA1 diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 023ef77f08e0..999c1e4ae9b3 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -748,6 +748,7 @@ def mamba_type(self) -> str: return "mamba2" + def mamba_mixer2( hidden_states: torch.Tensor, output: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 8ca97dbc6af6..8a7b1220b723 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -4,6 +4,7 @@ class MambaStateShapeCalculator: + @classmethod def mamba1_state_shape( cls, @@ -13,12 +14,15 @@ def mamba1_state_shape( conv_kernel: int, use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int]]: - conv_state_shape = (conv_kernel - 1, divide(intermediate_size, tp_world_size)) - + conv_state_shape = (conv_kernel - 1, + divide(intermediate_size, tp_world_size)) + if not use_v1: - conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) - - temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) + conv_state_shape = (divide(intermediate_size, + tp_world_size), conv_kernel - 1) + + temporal_state_shape = (divide(intermediate_size, + tp_world_size), state_size) return conv_state_shape, temporal_state_shape @classmethod @@ -43,12 +47,14 @@ def mamba2_state_shape( # contiguous along 'dim' axis conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) if not use_v1: - conv_state_shape = (divide(conv_dim, tp_world_size), conv_kernel - 1) + conv_state_shape = (divide(conv_dim, + tp_world_size), conv_kernel - 1) # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) - temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) + temporal_state_shape = (divide(num_heads, + tp_world_size), head_dim, state_size) return conv_state_shape, temporal_state_shape @classmethod @@ -61,4 +67,4 @@ def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): return 0 # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups \ No newline at end of file + return tp_size - ngroups diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 4cad803db0c6..a9f60b1eccd0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -111,7 +111,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config=cache_config, quant_config=quant_config, is_lora_enabled=is_lora_enabled, - prefix=prefix), + prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -147,7 +147,8 @@ def forward( layer_cache_params = None if mamba_cache_params is not None: - layer_cache_params = mamba_cache_params.at_layer_idx(i - self.start_layer) + layer_cache_params = mamba_cache_params.at_layer_idx( + i - self.start_layer) hidden_states, residual = layer( positions=positions, @@ -234,21 +235,23 @@ def forward(self, 
intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - + mamba_cache_params = None if not envs.VLLM_USE_V1: if self.mamba_cache is None: num_mamba_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) mamba_state_shape = MambaStateShapeCalculator.mamba1_state_shape( - tp_world_size=self.vllm_config.parallel_config.tensor_parallel_size, + tp_world_size=self.vllm_config.parallel_config. + tensor_parallel_size, intermediate_size=self.config.intermediate_size, state_size=self.config.state_size, conv_kernel=self.config.conv_kernel, use_v1=False) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *mamba_state_shape) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -263,7 +266,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index af54d3776153..06460bfb846f 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -48,7 +48,8 @@ class Mamba1AttentionMetadata: batch_ptr: Optional[torch.Tensor] = None -class Mamba1AttentionMetadataBuilder(AttentionMetadataBuilder[Mamba1AttentionMetadata]): +class Mamba1AttentionMetadataBuilder( + AttentionMetadataBuilder[Mamba1AttentionMetadata]): def __init__( self, @@ -80,11 +81,8 @@ def build( ) -> Mamba1AttentionMetadata: query_start_loc = common_attn_metadata.query_start_loc - seq_lens = ( - common_attn_metadata.seq_lens - .to(query_start_loc.device) - .to(torch.int32) - ) + seq_lens = (common_attn_metadata.seq_lens.to( + query_start_loc.device).to(torch.int32)) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills( @@ -105,4 +103,4 @@ def build( has_initial_states=has_initial_states.to(query_start_loc.device), state_indices_tensor=state_indices_tensor, cu_seqlen=common_attn_metadata.max_query_len, - ) \ No newline at end of file + ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index e65352ddf6b9..d719cdce57b0 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -178,6 +178,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # window [CDEF] of 6 tokens. 
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes + class MambaType(str, Enum): MAMBA1 = "mamba1" MAMBA2 = "mamba2" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4d54b4772c48..f2476d23cba6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2592,7 +2592,8 @@ def _initialize_single_attn_backend( "Non-Attention backend is not supported by V1 " "GPUModelRunner.") elif isinstance(kv_cache_spec, MambaSpec): - attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type) + attn_backend_i = get_mamba_attn_backend( + kv_cache_spec.mamba_type) else: raise ValueError( f"Unknown KV cache spec type: {type(kv_cache_spec)}") @@ -2676,7 +2677,8 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_metadata_builders.append(attn_metadata_builder) self.is_encoder_only_model = True - def calculate_reorder_batch_threshold(self) -> None: + def calculate_reorder_batch_threshold( + self) -> None: """ Check that if any backends reorder batches; that the reordering is compatible (e.g., decode threshold is the same) From 928927397577ee338081884d64f856231056dd8d Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 14:23:47 +0300 Subject: [PATCH 11/44] fix: Lint Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 1413a674ed0c..190bd39a944c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -8,11 +8,10 @@ from torch.nn.parameter import Parameter from vllm import envs -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -164,8 +163,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): - attn_metadata: AttentionMetadata | Mamba1AttentionMetadata = - get_forward_context().attn_metadata + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata if envs.VLLM_USE_V1: if attn_metadata is not None: From 3402d3a0ffe35463b46efda5b069b7ead49cb004 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 20 Jul 2025 15:51:08 +0300 Subject: [PATCH 12/44] fix: Ruff long lines Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 3 ++- vllm/model_executor/models/mamba.py | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 999c1e4ae9b3..49b12a34204a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -279,8 +279,9 @@ def __init__( # - for TP we shard conv_dim by sharding on n_groups, # - but if n_groups cannot divide tp_size, we need to # extend some extra groups - self.n_groups = n_groups + MambaStateShapeCalculator.extra_groups_for_head_shards( + 
groups = MambaStateShapeCalculator.extra_groups_for_head_shards( n_groups, self.tp_size) + self.n_groups = n_groups + groups self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size self.conv1d = ColumnParallelLinear( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index a9f60b1eccd0..ae2b4aea81a9 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -239,9 +239,9 @@ def forward(self, mamba_cache_params = None if not envs.VLLM_USE_V1: if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( + num_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) - mamba_state_shape = MambaStateShapeCalculator.mamba1_state_shape( + state_shape = MambaStateShapeCalculator.mamba1_state_shape( tp_world_size=self.vllm_config.parallel_config. tensor_parallel_size, intermediate_size=self.config.intermediate_size, @@ -250,8 +250,7 @@ def forward(self, use_v1=False) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, - *mamba_state_shape) + num_layers, *state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) From 4d540128bd067f635ece0274c89bda45cea36609 Mon Sep 17 00:00:00 2001 From: asafg Date: Mon, 21 Jul 2025 23:15:21 +0300 Subject: [PATCH 13/44] fix: Added context_lens_tensor Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 2 +- vllm/v1/attention/backends/mamba1_attn.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 190bd39a944c..de83c3e73622 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -179,7 +179,7 @@ def forward_cuda(self, conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1].contiguous() has_initial_state = mamba1_metadata.has_initial_states - context_lens_tensor = mamba1_metadata.seq_lens + context_lens_tensor = mamba1_metadata.context_lens_tensor else: assert mamba_cache_params is not None conv_state = mamba_cache_params.conv_state diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 06460bfb846f..9fc90235e79f 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -41,6 +41,7 @@ class Mamba1AttentionMetadata: num_decode_tokens: int query_start_loc: torch.Tensor seq_lens: torch.Tensor + context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor has_initial_states: torch.Tensor cu_seqlen: int @@ -90,8 +91,10 @@ def build( decode_threshold=1, ) - has_initial_states = (seq_lens > 0) state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( + query_start_loc.device) + has_initial_states = (context_lens_tensor > 0) return Mamba1AttentionMetadata( num_prefills=num_prefills, @@ -99,6 +102,7 @@ def build( num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, query_start_loc=query_start_loc, + context_lens_tensor=context_lens_tensor, seq_lens=seq_lens, has_initial_states=has_initial_states.to(query_start_loc.device), state_indices_tensor=state_indices_tensor, From 898306fe9a3709acd9061c56f45e693ef7b1b050 Mon Sep 17 00:00:00 2001 From: asafg Date: Tue, 22 Jul 2025 11:49:23 +0300 Subject: [PATCH 14/44] 
feat: Updated jamba code to support v1 Signed-off-by: asafg --- vllm/model_executor/models/jamba.py | 51 ++++++++++++++++++----------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 263f4c8379cf..071e1aa40d3b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -8,6 +8,7 @@ from torch import nn from transformers import JambaConfig +from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -112,7 +113,8 @@ def __init__(self, use_rms_norm=True, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, - is_lora_enabled = self.is_lora_enabled + is_lora_enabled = self.is_lora_enabled, + prefix=f"{prefix}.mixer", ) num_experts = config.layers_num_experts[layer_idx] @@ -344,7 +346,8 @@ def forward( layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache_index += 1 - if isinstance(layer, JambaMambaDecoderLayer): + if isinstance(layer, + JambaMambaDecoderLayer) and mamba_cache_params: current_state_layer = mamba_cache_index layer_mamba_cache_params = mamba_cache_params.at_layer_idx( current_state_layer) @@ -441,6 +444,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params +# TODO: Remove SupportsV0Only once v1 is fully supported class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsV0Only): hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ @@ -509,12 +513,17 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if self.mamba_cache is None: - num_mamba_layers = self.model_config.get_num_layers_by_block_type( + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None + + if not envs.VLLM_USE_V1 and self.mamba_cache is None: + num_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, - *self._get_mamba_cache_shape()) + state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_layers, *state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -529,19 +538,23 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - conv_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_conv - 1, - ) - temporal_state_shape = ( - self.config.mamba_expand * hidden_size // world_size, - self.config.mamba_d_state, + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + hidden_size = hf_config.hidden_size + + return MambaStateShapeCalculator.mamba1_state_shape( + 
tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.mamba_expand * hidden_size, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, ) - return conv_state_shape, temporal_state_shape def compute_logits( self, From 60c1840cb7200d521a74e2dab8fee9a95364cc99 Mon Sep 17 00:00:00 2001 From: asafg Date: Tue, 22 Jul 2025 13:44:08 +0300 Subject: [PATCH 15/44] fix: CR changes Signed-off-by: asafg --- vllm/attention/mamba_selectors.py | 14 ++++++++++++++ vllm/model_executor/layers/mamba/abstract.py | 2 -- vllm/model_executor/layers/mamba/mamba_mixer.py | 16 +++++++++------- vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 -- vllm/v1/kv_cache_interface.py | 5 ----- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 6 files changed, 26 insertions(+), 18 deletions(-) create mode 100644 vllm/attention/mamba_selectors.py diff --git a/vllm/attention/mamba_selectors.py b/vllm/attention/mamba_selectors.py new file mode 100644 index 000000000000..3a592bd8f16b --- /dev/null +++ b/vllm/attention/mamba_selectors.py @@ -0,0 +1,14 @@ +from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend +from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend + + +def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: + if mamba_type == "mamba1": + return Mamba1AttentionBackend + + if mamba_type == "mamba2": + return Mamba2AttentionBackend + + raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " + "supported yet.") \ No newline at end of file diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 0c90c062053a..daebe46f6f77 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -5,8 +5,6 @@ import torch -from vllm.v1.kv_cache_interface import MambaType - class MambaBase(ABC): """ diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index de83c3e73622..4207a2f6067b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -27,7 +27,6 @@ from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata -from vllm.v1.kv_cache_interface import MambaType # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @@ -147,10 +146,10 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - # The outer list is for v0 PP virtual engine - # Initialize with empty tensors in the correct format - # conv_state should be in (batch, width-1, dim) format - # so when transposed it becomes (batch, dim, width-1) + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. 
+ # The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] self.prefix = prefix @@ -176,6 +175,9 @@ def forward_cuda(self, query_start_loc = mamba1_metadata.query_start_loc state_indices_tensor = mamba1_metadata.state_indices_tensor self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + # conv_state should be in (batch, width-1, dim) format + # so when transposed it becomes (batch, dim, width-1) conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1].contiguous() has_initial_state = mamba1_metadata.has_initial_states @@ -310,5 +312,5 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: ) @property - def mamba_type(self) -> MambaType: - return MambaType.MAMBA1 + def mamba_type(self) -> str: + return "mamba1" diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 49b12a34204a..d5f4877135c9 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -37,7 +37,6 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata -from vllm.v1.kv_cache_interface import MambaType # Added by the IBM Team, 2024 @@ -749,7 +748,6 @@ def mamba_type(self) -> str: return "mamba2" - def mamba_mixer2( hidden_states: torch.Tensor, output: torch.Tensor, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index d719cdce57b0..4ff96f9786b8 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -179,11 +179,6 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes -class MambaType(str, Enum): - MAMBA1 = "mamba1" - MAMBA2 = "mamba2" - - @dataclass(frozen=True) class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f2476d23cba6..69cb761e547a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -17,6 +17,7 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention +from vllm.attention.mamba_selectors import get_mamba_attn_backend from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config, update_config) @@ -56,7 +57,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, MambaType, + KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -2592,7 +2593,7 @@ def _initialize_single_attn_backend( "Non-Attention backend is not supported by V1 " "GPUModelRunner.") elif isinstance(kv_cache_spec, MambaSpec): - attn_backend_i = get_mamba_attn_backend( + attn_backend_i = get_mamba_attn_backend( kv_cache_spec.mamba_type) else: raise ValueError( From f0566bf0eee480cef698132d684689773668dd00 Mon Sep 17 00:00:00 2001 From: asafg Date: Tue, 22 Jul 2025 13:51:24 +0300 Subject: [PATCH 16/44] fix: Conflicts Signed-off-by: asafg --- vllm/model_executor/models/jamba.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 071e1aa40d3b..2001d21d0bd6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -20,6 +20,8 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateShapeCalculator) from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig From 4a4f9b1915d6fcb859b5a51539596c286509e48b Mon Sep 17 00:00:00 2001 From: asafg Date: Tue, 22 Jul 2025 15:13:51 +0300 Subject: [PATCH 17/44] fix: Lint Signed-off-by: asafg --- vllm/attention/mamba_selectors.py | 6 ++++-- vllm/model_executor/models/jamba.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/attention/mamba_selectors.py b/vllm/attention/mamba_selectors.py index 3a592bd8f16b..f56f2fb7bf69 100644 --- a/vllm/attention/mamba_selectors.py +++ b/vllm/attention/mamba_selectors.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend @@ -6,9 +8,9 @@ def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: if mamba_type == "mamba1": return Mamba1AttentionBackend - + if mamba_type == "mamba2": return Mamba2AttentionBackend raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " - "supported yet.") \ No newline at end of file + "supported yet.") diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2001d21d0bd6..bc6dcccfa92b 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -515,9 +515,6 @@ def 
forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None - if not envs.VLLM_USE_V1 and self.mamba_cache is None: num_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) @@ -527,7 +524,10 @@ def forward(self, self.lm_head.weight.dtype, num_layers, *state_shape) - mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + else: + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) From 1b1bad20d501306f583c27278c56273f0e3cb2be Mon Sep 17 00:00:00 2001 From: asafg Date: Tue, 22 Jul 2025 15:19:39 +0300 Subject: [PATCH 18/44] fix: Jamba forward Signed-off-by: asafg --- vllm/model_executor/models/jamba.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index bc6dcccfa92b..326671d1238a 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -515,19 +515,19 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): - if not envs.VLLM_USE_V1 and self.mamba_cache is None: - num_layers = self.model_config.get_num_layers_by_block_type( - self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) - self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, - num_layers, *state_shape) + # NOTE: mamba_cache_params is not needed for v1 + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_layers = self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, LayerBlockType.mamba) + state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_layers, *state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) - else: - # NOTE: mamba_cache_params is not needed for v1 - mamba_cache_params = None hidden_states = self.model(input_ids, positions, mamba_cache_params, intermediate_tensors, inputs_embeds) From f6d9311038f61f08baf8c91a28c7f260b9a9d4de Mon Sep 17 00:00:00 2001 From: asafg Date: Tue, 22 Jul 2025 23:03:59 +0300 Subject: [PATCH 19/44] refactor: Removed unnecessary fields Signed-off-by: asafg --- vllm/v1/attention/backends/mamba1_attn.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 9fc90235e79f..4cb376b18204 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch @@ -27,14 +27,6 @@ def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: @dataclass class Mamba1AttentionMetadata: - """ - Attention metadata for Mamba1 models. 
- - Mamba1 is simpler than Mamba2: - - No chunking/grouping - - No multi-head structure - - Simpler state management - """ num_prefills: int num_prefill_tokens: int num_decodes: int @@ -44,9 +36,6 @@ class Mamba1AttentionMetadata: context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor has_initial_states: torch.Tensor - cu_seqlen: int - nums_dict: Optional[dict] = None - batch_ptr: Optional[torch.Tensor] = None class Mamba1AttentionMetadataBuilder( @@ -106,5 +95,4 @@ def build( seq_lens=seq_lens, has_initial_states=has_initial_states.to(query_start_loc.device), state_indices_tensor=state_indices_tensor, - cu_seqlen=common_attn_metadata.max_query_len, ) From dd098a4ed86467ebf41f6cbe8438cf9a31148dbf Mon Sep 17 00:00:00 2001 From: asafg Date: Wed, 23 Jul 2025 09:58:35 +0300 Subject: [PATCH 20/44] refactor: Moved mamba_selectors Signed-off-by: asafg --- vllm/attention/mamba_selectors.py | 16 ---------------- vllm/v1/attention/backends/mamba1_attn.py | 2 +- vllm/v1/attention/backends/mamba_selectors.py | 4 ++++ vllm/v1/worker/gpu_model_runner.py | 1 - 4 files changed, 5 insertions(+), 18 deletions(-) delete mode 100644 vllm/attention/mamba_selectors.py diff --git a/vllm/attention/mamba_selectors.py b/vllm/attention/mamba_selectors.py deleted file mode 100644 index f56f2fb7bf69..000000000000 --- a/vllm/attention/mamba_selectors.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend -from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend - - -def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: - if mamba_type == "mamba1": - return Mamba1AttentionBackend - - if mamba_type == "mamba2": - return Mamba2AttentionBackend - - raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " - "supported yet.") diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 4cb376b18204..7d482b37a0ab 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -93,6 +93,6 @@ def build( query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, seq_lens=seq_lens, - has_initial_states=has_initial_states.to(query_start_loc.device), + has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, ) diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py index 80021a216556..f56f2fb7bf69 100644 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ b/vllm/v1/attention/backends/mamba_selectors.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.attention.backends.abstract import AttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: + if mamba_type == "mamba1": + return Mamba1AttentionBackend + if mamba_type == "mamba2": return Mamba2AttentionBackend diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 69cb761e547a..6e12059c3875 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -17,7 +17,6 @@ from vllm.attention import 
AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention -from vllm.attention.mamba_selectors import get_mamba_attn_backend from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config, update_config) From 364ea414e31921ea1bb450d3c624856a67ef4efe Mon Sep 17 00:00:00 2001 From: asafg Date: Wed, 23 Jul 2025 17:03:54 +0300 Subject: [PATCH 21/44] refactor: Removed v1 from mamab1 state shape Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 8 ++------ vllm/model_executor/layers/mamba/mamba_utils.py | 9 ++------- vllm/model_executor/models/jamba.py | 2 -- vllm/model_executor/models/mamba.py | 3 +-- 4 files changed, 5 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 4207a2f6067b..07603401d675 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -175,11 +175,8 @@ def forward_cuda(self, query_start_loc = mamba1_metadata.query_start_loc state_indices_tensor = mamba1_metadata.state_indices_tensor self_kv_cache = self.kv_cache[forward_context.virtual_engine] - - # conv_state should be in (batch, width-1, dim) format - # so when transposed it becomes (batch, dim, width-1) - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1].contiguous() + conv_state = self_kv_cache[0] + ssm_state = self_kv_cache[1] has_initial_state = mamba1_metadata.has_initial_states context_lens_tensor = mamba1_metadata.context_lens_tensor else: @@ -308,7 +305,6 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: intermediate_size=self.intermediate_size, state_size=self.ssm_state_size, conv_kernel=self.conv_kernel_size, - use_v1=envs.VLLM_USE_V1, ) @property diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 8a7b1220b723..6d655d15edfc 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -12,14 +12,9 @@ def mamba1_state_shape( intermediate_size: int, state_size: int, conv_kernel: int, - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int]]: - conv_state_shape = (conv_kernel - 1, - divide(intermediate_size, tp_world_size)) - - if not use_v1: - conv_state_shape = (divide(intermediate_size, - tp_world_size), conv_kernel - 1) + conv_state_shape = (divide(intermediate_size, + tp_world_size), conv_kernel - 1) temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 326671d1238a..c4dc57a7069f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -544,7 +544,6 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", - use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config @@ -555,7 +554,6 @@ def get_mamba_state_shape_from_config( intermediate_size=hf_config.mamba_expand * hidden_size, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, - use_v1=use_v1, ) def compute_logits( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py 
index ae2b4aea81a9..8096757f6dac 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -246,8 +246,7 @@ def forward(self, tensor_parallel_size, intermediate_size=self.config.intermediate_size, state_size=self.config.state_size, - conv_kernel=self.config.conv_kernel, - use_v1=False) + conv_kernel=self.config.conv_kernel) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, num_layers, *state_shape) From 19d54a61a837890c1098319e39d29388b6d3954e Mon Sep 17 00:00:00 2001 From: asafg Date: Wed, 23 Jul 2025 17:36:20 +0300 Subject: [PATCH 22/44] refactor: Added _create_mamba1_state_tensors Signed-off-by: asafg --- vllm/v1/worker/gpu_model_runner.py | 64 +++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6e12059c3875..569afbc1473b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2820,18 +2820,26 @@ def _reshape_kv_cache_tensors( get_dtype_size(dtype)) state_tensors = [] storage_offset = 0 - for shape in kv_cache_spec.shapes: - target_shape = (num_blocks, *shape) - stride = torch.empty(target_shape).stride() - target_stride = (num_element_per_page, *stride[1:]) - tensor = torch.as_strided( - raw_tensor.view(dtype), - size=target_shape, - stride=target_stride, - storage_offset=storage_offset, + + if kv_cache_spec.mamba_type == "mamba1": + state_tensors = self._create_mamba1_state_tensors( + raw_tensor, dtype, kv_cache_spec.shapes, storage_offset ) - state_tensors.append(tensor) - storage_offset += stride[0] + state_tensors.extend(state_tensors) + else: + # Handle other mamba types + for shape in kv_cache_spec.shapes: + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset, + ) + storage_offset += stride[0] + state_tensors.append(tensor) kv_caches[layer_name] = state_tensors else: @@ -2843,6 +2851,7 @@ def _reshape_kv_cache_tensors( return kv_caches + def _verify_hybrid_attention_mamba_layout( self, kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: @@ -3076,3 +3085,36 @@ def _build_encoder_only_attn_metadata( common_prefix_len=0, # No cascade for encoder common_attn_metadata=common_metadata, ) + + def _create_mamba1_state_tensors( + self, + raw_tensor: torch.Tensor, + dtype: torch.dtype, + shapes: tuple, + storage_offset: int + ) -> list[torch.Tensor]: + conv_state_shape, temporal_state_shape = shapes + num_sequences = len(self.seq_lens) + + conv_target_shape = (num_sequences, conv_state_shape[1], conv_state_shape[0]) + conv_stride = torch.empty(conv_target_shape).stride() + conv_state = torch.as_strided( + raw_tensor.view(dtype), + size=conv_target_shape, + stride=conv_stride, + storage_offset=storage_offset, + ).transpose(-1, -2) + + conv_elements = conv_target_shape[0] * conv_target_shape[1] * conv_target_shape[2] + storage_offset += conv_elements + + temporal_target_shape = (num_sequences, temporal_state_shape[0], temporal_state_shape[1]) + temporal_stride = torch.empty(temporal_target_shape).stride() + temporal_state = torch.as_strided( + raw_tensor.view(dtype), + size=temporal_target_shape, + stride=temporal_stride, + storage_offset=storage_offset, + ) + + return [conv_state, temporal_state] From 
e7d3e7d070cd3a99a4e66b6cced1545e05257080 Mon Sep 17 00:00:00 2001 From: asafg Date: Wed, 23 Jul 2025 17:37:56 +0300 Subject: [PATCH 23/44] fix: Lint Signed-off-by: asafg --- .../layers/mamba/mamba_utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 22 +++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 6d655d15edfc..de93dbb23304 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -14,7 +14,7 @@ def mamba1_state_shape( conv_kernel: int, ) -> tuple[tuple[int, int], tuple[int, int]]: conv_state_shape = (divide(intermediate_size, - tp_world_size), conv_kernel - 1) + tp_world_size), conv_kernel - 1) temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 569afbc1473b..6161c53d2ce2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2823,8 +2823,8 @@ def _reshape_kv_cache_tensors( if kv_cache_spec.mamba_type == "mamba1": state_tensors = self._create_mamba1_state_tensors( - raw_tensor, dtype, kv_cache_spec.shapes, storage_offset - ) + raw_tensor, dtype, kv_cache_spec.shapes, + storage_offset) state_tensors.extend(state_tensors) else: # Handle other mamba types @@ -2851,7 +2851,6 @@ def _reshape_kv_cache_tensors( return kv_caches - def _verify_hybrid_attention_mamba_layout( self, kv_cache_config: KVCacheConfig, kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: @@ -3087,16 +3086,13 @@ def _build_encoder_only_attn_metadata( ) def _create_mamba1_state_tensors( - self, - raw_tensor: torch.Tensor, - dtype: torch.dtype, - shapes: tuple, - storage_offset: int - ) -> list[torch.Tensor]: + self, raw_tensor: torch.Tensor, dtype: torch.dtype, shapes: tuple, + storage_offset: int) -> list[torch.Tensor]: conv_state_shape, temporal_state_shape = shapes num_sequences = len(self.seq_lens) - conv_target_shape = (num_sequences, conv_state_shape[1], conv_state_shape[0]) + conv_target_shape = (num_sequences, conv_state_shape[1], + conv_state_shape[0]) conv_stride = torch.empty(conv_target_shape).stride() conv_state = torch.as_strided( raw_tensor.view(dtype), @@ -3105,10 +3101,12 @@ def _create_mamba1_state_tensors( storage_offset=storage_offset, ).transpose(-1, -2) - conv_elements = conv_target_shape[0] * conv_target_shape[1] * conv_target_shape[2] + conv_elements = conv_target_shape[0] * conv_target_shape[ + 1] * conv_target_shape[2] storage_offset += conv_elements - temporal_target_shape = (num_sequences, temporal_state_shape[0], temporal_state_shape[1]) + temporal_target_shape = (num_sequences, temporal_state_shape[0], + temporal_state_shape[1]) temporal_stride = torch.empty(temporal_target_shape).stride() temporal_state = torch.as_strided( raw_tensor.view(dtype), From e20516ed4b668ae8eb935fcec2fb662c6e569564 Mon Sep 17 00:00:00 2001 From: asafg Date: Wed, 23 Jul 2025 20:54:51 +0300 Subject: [PATCH 24/44] test: Updated ssm tests to work in test_hybrid.py Signed-off-by: asafg --- .../models/language/generation/test_hybrid.py | 25 +++++++++++++------ vllm/model_executor/models/jamba.py | 2 +- vllm/model_executor/models/mamba.py | 22 +++++++++++----- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 2238924c1b50..de954ee40d65 100644 --- 
a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -53,6 +53,7 @@ ] V1_SUPPORTED_MODELS = [ + "state-spaces/mamba-130m-hf", "mistralai/Mamba-Codestral-7B-v0.1", "ibm-ai-platform/Bamba-9B-v1", "Zyphra/Zamba2-1.2B-instruct", @@ -97,14 +98,19 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: + enforce_eager = False with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS: - # required due to reorder_batch behaviour + if model in HYBRID_MODELS + SSM_MODELS: m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + if model in SSM_MODELS: + # Set to True until support in CUDA Graphs + enforce_eager = True + with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - enable_prefix_caching=False) as vllm_model: + enable_prefix_caching=False, + enforce_eager=enforce_eager) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: @@ -371,13 +377,18 @@ def test_distributed_correctness( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model, tensor_parallel_size=1, - max_num_seqs=2) as vllm_model: + # Set enforce_eager=True until support in CUDA Graphs + with vllm_runner(model, + tensor_parallel_size=1, + max_num_seqs=2, + enforce_eager=True) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, tensor_parallel_size=2, - max_num_seqs=2) as vllm_model: + with vllm_runner(model, + tensor_parallel_size=2, + max_num_seqs=2, + enforce_eager=True) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index c4dc57a7069f..ca936c96a57d 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -522,7 +522,7 @@ def forward(self, num_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) state_shape = self.get_mamba_state_shape_from_config( - self.vllm_config, use_v1=False) + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, num_layers, *state_shape) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 8096757f6dac..6f8a398fa1b5 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -241,12 +241,8 @@ def forward(self, if self.mamba_cache is None: num_layers = self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba) - state_shape = MambaStateShapeCalculator.mamba1_state_shape( - tp_world_size=self.vllm_config.parallel_config. 
- tensor_parallel_size, - intermediate_size=self.config.intermediate_size, - state_size=self.config.state_size, - conv_kernel=self.config.conv_kernel) + state_shape = self.get_mamba_state_shape_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, self.lm_head.weight.dtype, num_layers, *state_shape) @@ -257,6 +253,20 @@ def forward(self, intermediate_tensors, inputs_embeds) return hidden_states + + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[tuple[int, int], tuple[int, int]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + return MambaStateShapeCalculator.mamba1_state_shape( + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.intermediate_size, + state_size=hf_config.state_size, + conv_kernel=hf_config.conv_kernel) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( From e0619a48d8673407c20be07ccd5a2e38fe9f4809 Mon Sep 17 00:00:00 2001 From: asafg Date: Thu, 24 Jul 2025 17:58:09 +0300 Subject: [PATCH 25/44] fix: Extra extend to state tensors Signed-off-by: asafg --- vllm/model_executor/models/mamba.py | 10 +++++----- vllm/v1/worker/gpu_model_runner.py | 1 - 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 6f8a398fa1b5..8a57db1cffde 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -253,7 +253,7 @@ def forward(self, intermediate_tensors, inputs_embeds) return hidden_states - + @classmethod def get_mamba_state_shape_from_config( cls, @@ -263,10 +263,10 @@ def get_mamba_state_shape_from_config( hf_config = vllm_config.model_config.hf_config return MambaStateShapeCalculator.mamba1_state_shape( - tp_world_size=parallel_config.tensor_parallel_size, - intermediate_size=hf_config.intermediate_size, - state_size=hf_config.state_size, - conv_kernel=hf_config.conv_kernel) + tp_world_size=parallel_config.tensor_parallel_size, + intermediate_size=hf_config.intermediate_size, + state_size=hf_config.state_size, + conv_kernel=hf_config.conv_kernel) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6161c53d2ce2..f1dbe167178a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2825,7 +2825,6 @@ def _reshape_kv_cache_tensors( state_tensors = self._create_mamba1_state_tensors( raw_tensor, dtype, kv_cache_spec.shapes, storage_offset) - state_tensors.extend(state_tensors) else: # Handle other mamba types for shape in kv_cache_spec.shapes: From d1c70630f48aeb5a3c38b1575ca843c4d2bbadd4 Mon Sep 17 00:00:00 2001 From: asafg Date: Thu, 24 Jul 2025 22:34:04 +0300 Subject: [PATCH 26/44] fix: Moved logic to create_mamba2_state_tensors Signed-off-by: asafg --- .../models/language/generation/test_hybrid.py | 14 ++--- vllm/v1/worker/gpu_model_runner.py | 55 ++++++++++++------- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index de954ee40d65..f62e1b7128d3 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -101,7 +101,7 @@ def test_models( enforce_eager = False 
with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - if model in HYBRID_MODELS + SSM_MODELS: + if model in HYBRID_MODELS: m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") if model in SSM_MODELS: # Set to True until support in CUDA Graphs @@ -378,17 +378,13 @@ def test_distributed_correctness( num_logprobs: int, ) -> None: # Set enforce_eager=True until support in CUDA Graphs - with vllm_runner(model, - tensor_parallel_size=1, - max_num_seqs=2, - enforce_eager=True) as vllm_model: + with vllm_runner(model, tensor_parallel_size=1, + max_num_seqs=2) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, - tensor_parallel_size=2, - max_num_seqs=2, - enforce_eager=True) as vllm_model: + with vllm_runner(model, tensor_parallel_size=2, + max_num_seqs=2) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f1dbe167178a..30a09cf78442 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2818,27 +2818,19 @@ def _reshape_kv_cache_tensors( dtype = kv_cache_spec.dtype num_element_per_page = (kv_cache_spec.page_size_bytes // get_dtype_size(dtype)) - state_tensors = [] - storage_offset = 0 if kv_cache_spec.mamba_type == "mamba1": state_tensors = self._create_mamba1_state_tensors( - raw_tensor, dtype, kv_cache_spec.shapes, - storage_offset) + raw_tensor=raw_tensor, + dtype=dtype, + shapes=kv_cache_spec.shapes) else: - # Handle other mamba types - for shape in kv_cache_spec.shapes: - target_shape = (num_blocks, *shape) - stride = torch.empty(target_shape).stride() - target_stride = (num_element_per_page, *stride[1:]) - tensor = torch.as_strided( - raw_tensor.view(dtype), - size=target_shape, - stride=target_stride, - storage_offset=storage_offset, - ) - storage_offset += stride[0] - state_tensors.append(tensor) + state_tensors = self._create_mamba2_state_tensors( + raw_tensor=raw_tensor, + shapes=kv_cache_spec.shapes, + num_blocks=num_blocks, + num_element_per_page=num_element_per_page, + dtype=dtype) kv_caches[layer_name] = state_tensors else: @@ -3084,9 +3076,10 @@ def _build_encoder_only_attn_metadata( common_attn_metadata=common_metadata, ) - def _create_mamba1_state_tensors( - self, raw_tensor: torch.Tensor, dtype: torch.dtype, shapes: tuple, - storage_offset: int) -> list[torch.Tensor]: + def _create_mamba1_state_tensors(self, raw_tensor: torch.Tensor, + dtype: torch.dtype, + shapes: tuple) -> list[torch.Tensor]: + storage_offset = 0 conv_state_shape, temporal_state_shape = shapes num_sequences = len(self.seq_lens) @@ -3115,3 +3108,25 @@ def _create_mamba1_state_tensors( ) return [conv_state, temporal_state] + + def _create_mamba2_state_tensors(self, raw_tensor: torch.Tensor, + shapes: tuple, num_blocks: int, + num_element_per_page: int, + dtype: torch.dtype) -> list[torch.Tensor]: + state_tensors = [] + storage_offset = 0 + + for shape in shapes: + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset, + ) + storage_offset += stride[0] + state_tensors.append(tensor) + + return state_tensors From c3577311bf2e40a1dfbae219c2e07e74e73ddfeb Mon Sep 17 00:00:00 2001 From: asafg Date: Thu, 24 Jul 2025 22:37:08 
+0300 Subject: [PATCH 27/44] fix: Order in conv state shape Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index de93dbb23304..1b527ed6d859 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -42,8 +42,7 @@ def mamba2_state_shape( # contiguous along 'dim' axis conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) if not use_v1: - conv_state_shape = (divide(conv_dim, - tp_world_size), conv_kernel - 1) + conv_state_shape = conv_state_shape[1], conv_state_shape[0] # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small From 8178700e434de91e6fb7cfb46570fb304b9948c1 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 27 Jul 2025 11:17:15 +0300 Subject: [PATCH 28/44] fix: Conflicts Signed-off-by: asafg --- vllm/v1/worker/gpu_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 30a09cf78442..81ce5de5fc17 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2592,8 +2592,7 @@ def _initialize_single_attn_backend( "Non-Attention backend is not supported by V1 " "GPUModelRunner.") elif isinstance(kv_cache_spec, MambaSpec): - attn_backend_i = get_mamba_attn_backend( - kv_cache_spec.mamba_type) + attn_backend_i = get_mamba_attn_backend(kv_cache_spec.mamba_type) else: raise ValueError( f"Unknown KV cache spec type: {type(kv_cache_spec)}") From 7525091c88d476f80dada52adb6cd9e227109f0b Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 13:55:43 +0300 Subject: [PATCH 29/44] fix: Added transponse to mixer Signed-off-by: asafg --- .../layers/mamba/mamba_mixer.py | 2 +- .../layers/mamba/mamba_utils.py | 10 +- vllm/v1/worker/gpu_model_runner.py | 531 +++++++++--------- 3 files changed, 283 insertions(+), 260 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 07603401d675..8b5ae1fb8a63 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -175,7 +175,7 @@ def forward_cuda(self, query_start_loc = mamba1_metadata.query_start_loc state_indices_tensor = mamba1_metadata.state_indices_tensor self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0] + conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_state = mamba1_metadata.has_initial_states context_lens_tensor = mamba1_metadata.context_lens_tensor diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 1b527ed6d859..4cb20debb5e4 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm import envs from vllm.distributed import divide @@ -18,6 +19,11 @@ def mamba1_state_shape( temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) + + if envs.VLLM_USE_V1: + return (conv_state_shape[1], + conv_state_shape[0]), temporal_state_shape + return conv_state_shape, temporal_state_shape @classmethod @@ -34,8 +40,8 @@ def mamba2_state_shape( ) -> 
tuple[tuple[int, int], tuple[int, int, int]]: # if n_groups is not divisible by world_size, need to extend the shards # to ensure all groups needed by a head is sharded along with it - n_groups = (n_groups + - cls.extra_groups_for_head_shards(n_groups, tp_world_size)) + n_groups = n_groups + cls.extra_groups_for_head_shards( + n_groups, tp_world_size) # heads and n_groups are TP-ed conv_dim = intermediate_size + 2 * n_groups * state_size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 81ce5de5fc17..0430af33ebca 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -87,8 +87,10 @@ else: xgr = LazyLoader("xgr", globals(), "xgrammar") xgr_torch_compile = LazyLoader( - "xgr_torch_compile", globals(), - "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") + "xgr_torch_compile", + globals(), + "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile", + ) logger = init_logger(__name__) @@ -112,6 +114,7 @@ def __init__( self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes + set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -279,7 +282,8 @@ def __init__( (3, self.max_num_tokens + 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) self.mrope_positions_np = self.mrope_positions_cpu.numpy() # Only relevant for models using ALiBi (e.g, MPT) @@ -288,35 +292,45 @@ def __init__( self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, - device=self.device) + device=self.device, + ) # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(max(self.max_num_reqs + 1, - self.max_model_len, - self.max_num_tokens), - dtype=np.int64) + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, + self.max_num_tokens), + dtype=np.int64, + ) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. Thus, we should # not make any assumptions about the values in these tensors. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory) + self.input_ids_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) + self.positions_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) self.positions_np = self.positions_cpu.numpy() - self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + self.query_start_loc_cpu = torch.zeros( + self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + self.seq_lens_cpu = torch.zeros( + self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.seq_lens_np = self.seq_lens_cpu.numpy() # Layer pairings for cross-layer KV sharing. 
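Aside on the mamba_utils.py hunk earlier in this patch: a minimal standalone sketch of how the mamba2 shape math works out once the group count is padded up to the TP size. The hyperparameter values below are illustrative only (not taken from any specific model config), and the helper names are made up for the example.

def pad_groups(n_groups: int, tp_size: int) -> int:
    # mirrors extra_groups_for_head_shards: replicate groups so every
    # head shard carries the groups it needs (exact only when the
    # original n_groups is 1 or already divides tp_size)
    return n_groups if n_groups % tp_size == 0 else n_groups + (tp_size - n_groups)

def mamba2_shapes(intermediate_size, n_groups, num_heads, head_dim,
                  state_size, conv_kernel, tp_size):
    n_groups = pad_groups(n_groups, tp_size)
    conv_dim = intermediate_size + 2 * n_groups * state_size
    conv_state = (conv_kernel - 1, conv_dim // tp_size)
    temporal_state = (num_heads // tp_size, head_dim, state_size)
    return conv_state, temporal_state

# hypothetical Mamba2-style settings with tp_size=4
print(mamba2_shapes(intermediate_size=8192, n_groups=1, num_heads=128,
                    head_dim=64, state_size=128, conv_kernel=4, tp_size=4))
# -> ((3, 2304), (32, 64, 128))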
@@ -359,8 +373,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # Note: used for model runner override. def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties - """ + """Initialize attributes from torch.cuda.get_device_properties""" self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -421,16 +434,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if (sampling_params and sampling_params.sampling_type + == SamplingType.RANDOM_SEED): generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: generator = None if pooling_params: - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") + assert (task := pooling_params.task + ) is not None, "You did not set `task` in the API" model = cast(VllmModelForPooling, self.model) to_update = model.pooler.get_pooling_updates(task) @@ -475,17 +488,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: hf_config = self.model_config.hf_config - self.requests[req_id].mrope_positions, \ - self.requests[req_id].mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + ( + self.requests[req_id].mrope_positions, + self.requests[req_id].mrope_position_delta, + ) = MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) req_ids_to_add.append(req_id) @@ -537,8 +551,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[ + req_index] = num_computed_tokens self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu @@ -555,8 +569,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. 
- spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, ()) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] @@ -634,9 +648,14 @@ def _get_cumsum_and_arange( def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, - Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata]]: + ) -> tuple[ + dict[str, Any], + bool, + torch.Tensor, + Optional[SpecDecodeMetadata], + np.ndarray, + Optional[CommonAttentionMetadata], + ]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -671,9 +690,11 @@ def _prepare_inputs( # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -690,10 +711,12 @@ def _prepare_inputs( # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) @@ -715,7 +738,8 @@ def _prepare_inputs( # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + non_blocking=True, + ) else: # Common case (1D positions) self.positions[:total_num_scheduled_tokens].copy_( @@ -792,9 +816,8 @@ def _prepare_inputs( # Prepare encoder attention metadata separately # (encoder layers are not in KV cache groups) if self.is_encoder_only_model: - common_attn_metadata, encoder_attn_metadata = \ - self._build_encoder_only_attn_metadata( - scheduler_output) + common_attn_metadata, encoder_attn_metadata = ( + self._build_encoder_only_attn_metadata(scheduler_output)) # Add encoder attention metadata for all encoder layers attention_layers = get_layers_from_vllm_config( @@ -831,15 +854,16 @@ def _prepare_inputs( causal=True, ) - if self.speculative_config and \ - spec_decode_common_attn_metadata is None: + if self.speculative_config and spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata if isinstance(kv_cache_group_spec.kv_cache_spec, ChunkedLocalAttentionSpec): common_attn_metadata = make_local_attention_virtual_batches( kv_cache_group_spec.kv_cache_spec.attention_chunk_size, - common_attn_metadata, self.cache_config.block_size) + common_attn_metadata, + self.cache_config.block_size, + ) # Prepare for cascade attention if enabled & beneficial. 
common_prefix_len = 0 @@ -853,10 +877,10 @@ def _prepare_inputs( builder, ) - attn_metadata_i = (builder.build( + attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - )) + ) fast_prefill_metadata = attn_metadata_i if (self.cache_config.kv_sharing_fast_prefill @@ -915,9 +939,14 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens, - spec_decode_common_attn_metadata) + return ( + attn_metadata, + attention_cuda_graphs, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens, + spec_decode_common_attn_metadata, + ) def _compute_cascade_attn_prefix_len( self, @@ -994,13 +1023,13 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) + use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None) + use_local_attention = isinstance( + kv_cache_spec, ChunkedLocalAttentionSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1020,10 +1049,10 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ + index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] num_prompt_tokens = len(req.prompt_token_ids) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: @@ -1044,8 +1073,10 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - req.mrope_positions[:,src_start:src_end] + self.mrope_positions_cpu[:, dst_start: + dst_end] = req.mrope_positions[:, + src_start: + src_end] mrope_pos_ptr += prompt_part_len @@ -1229,7 +1260,8 @@ def _gather_mm_embeddings( start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) + num_encoder_tokens, + ) assert start_idx < end_idx assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] @@ -1323,8 +1355,8 @@ def apply_grammar_bitmask( num_spec_tokens = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] + sorted_bitmask[logit_index + + i] = grammar_bitmask[cumulative_index + i] out_indices.append(logit_index + i) cumulative_index += 1 + 
num_spec_tokens grammar_bitmask = sorted_bitmask @@ -1343,20 +1375,21 @@ def apply_grammar_bitmask( ) def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors, + sync_self: bool, + ) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size - enabled_sp = self.compilation_config.pass_config. \ - enable_sequence_parallelism + enabled_sp = self.compilation_config.pass_config.enable_sequence_parallelism if enabled_sp: # When sequence parallelism is enabled, we always pad num_tokens # to be a multiple of tensor_parallel_size (tp) earlier assert num_tokens % tp == 0 - is_residual_scattered = tp > 1 and enabled_sp \ - and num_tokens % tp == 0 + is_residual_scattered = tp > 1 and enabled_sp and num_tokens % tp == 0 # When sequence parallelism is enabled, the "residual" tensor is sharded # across tensor parallel ranks, so each rank only needs its own slice. @@ -1364,15 +1397,13 @@ def sync_and_slice_intermediate_tensors( assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_residual_scattered - copy_len = num_tokens // tp if is_scattered else \ - num_tokens + copy_len = num_tokens // tp if is_scattered else num_tokens self.intermediate_tensors[k][:copy_len].copy_( v[:copy_len], non_blocking=True) return IntermediateTensors({ - k: - v[:num_tokens // tp] - if k == "residual" and is_residual_scattered else v[:num_tokens] + k: (v[:num_tokens // tp] if k == "residual" + and is_residual_scattered else v[:num_tokens]) for k, v in self.intermediate_tensors.items() }) @@ -1426,10 +1457,10 @@ def _pool( num_scheduled_tokens_np: np.ndarray, kv_connector_output: Optional[KVConnectorOutput], ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" + assert self.input_batch.num_reqs == len( + self.input_batch.pooling_params), ( + "Either all or none of the requests in" + " a batch must be pooling request") extracted_hidden_states = list( torch.split(hidden_states[:num_scheduled_tokens], @@ -1478,11 +1509,14 @@ def execute_model( self.vllm_config) # Prepare the decoder inputs. - (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens_np, - spec_decode_common_attn_metadata) = ( - self._prepare_inputs(scheduler_output)) - + ( + attn_metadata, + attention_cuda_graphs, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + ) = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1495,8 +1529,8 @@ def execute_model( # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.compilation_config.pass_config. 
\ - enable_sequence_parallelism and tp_size > 1: + if (self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1): num_input_tokens = round_up(num_scheduled_tokens, tp_size) else: num_input_tokens = num_scheduled_tokens @@ -1587,9 +1621,9 @@ def execute_model( # to make sure we are synced across pp ranks # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = \ - self.parallel_config.distributed_executor_backend \ - == "external_launcher" and len(get_pp_group().ranks) > 0 + broadcast_pp_output = ( + self.parallel_config.distributed_executor_backend + == "external_launcher" and len(get_pp_group().ranks) > 0) if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. assert isinstance(hidden_states, IntermediateTensors) @@ -1607,9 +1641,9 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) if broadcast_pp_output: - model_output_broadcast_data = { + model_output_broadcast_data = ({ "logits": logits.contiguous(), - } if logits is not None else {} + } if logits is not None else {}) model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) assert model_output_broadcast_data is not None @@ -1677,8 +1711,8 @@ def execute_model( # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None + logprobs_lists = (logprobs_tensors.tolists() + if logprobs_tensors is not None else None) # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -1833,8 +1867,7 @@ def propose_draft_token_ids( ] num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( + common_attn_metadata, token_indices = self.drafter.prepare_inputs( common_attn_metadata, num_rejected_tokens_cpu) target_token_ids = self.input_ids[token_indices] @@ -1899,9 +1932,9 @@ def propose_ngram_draft_token_ids( def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ - f"Allowed configs: {allowed_config_names}" + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" + f"Allowed configs: {allowed_config_names}") config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -1914,6 +1947,7 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group + num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") @@ -1922,15 +1956,15 @@ def load_model(self, eep_scale_up: bool = False) -> None: group_src=0) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) + global_expert_load, old_global_expert_indices = EplbState.recv_state( + ) num_logical_experts = global_expert_load.shape[1] self.parallel_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) assert old_global_expert_indices.shape[ 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts + old_ep_size = (old_global_expert_indices.shape[1] // + num_local_physical_experts) rank_mapping = { old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) @@ -1947,11 +1981,13 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config) if self.lora_config: - self.model = self.load_lora_model(self.model, - self.model_config, - self.scheduler_config, - self.lora_config, - self.device) + self.model = self.load_lora_model( + self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device, + ) if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) @@ -1960,9 +1996,11 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model.get_eagle3_aux_hidden_state_layers()) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + ) prepare_communication_buffer_for_model(self.model) if is_mixture_of_experts( @@ -1990,8 +2028,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: backend=backend) def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ - "Cannot reload weights before model is loaded." + assert (getattr(self, "model", None) + is not None), "Cannot reload weights before model is loaded." model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") model_loader.load_weights(self.model, model_config=self.model_config) @@ -2142,7 +2180,8 @@ def rand_input_ids() -> torch.Tensor: self.input_ids, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) + dtype=input_ids.dtype, + ) logger.debug("Randomizing dummy data for DP Rank") input_ids.copy_(rand_input_ids()[:input_ids.size(0)], @@ -2204,7 +2243,8 @@ def _dummy_run( kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. 
block_table[kv_cache_group_id].slot_mapping[:num_tokens], - causal=True) + causal=True, + ) attn_metadata_i = self.attn_metadata_builders[ kv_cache_group_id].build_for_cudagraph_capture( @@ -2238,16 +2278,21 @@ def _dummy_run( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device)) + device=self.device, + )) intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - with self.maybe_randomize_inputs(input_ids), set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + with ( + self.maybe_randomize_inputs(input_ids), + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ), + ): outputs = model( input_ids=input_ids, positions=positions, @@ -2319,7 +2364,7 @@ def _dummy_sampler_run( sampler_output = self.sampler(logits=logits, sampling_metadata=dummy_metadata) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " @@ -2397,7 +2442,7 @@ def _dummy_pooler_run_task( return model.pooler(hidden_states=hidden_states_list, pooling_metadata=dummy_metadata) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " @@ -2431,8 +2476,9 @@ def profile_run(self) -> None: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. - max_tokens_by_modality_dict = self.mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(self.model_config) + max_tokens_by_modality_dict = ( + self.mm_registry.get_max_tokens_per_item_by_nonzero_modality( + self.model_config)) dummy_data_modality, max_tokens_per_mm_item = max( max_tokens_by_modality_dict.items(), key=lambda item: item[1]) @@ -2441,8 +2487,7 @@ def profile_run(self) -> None: encoder_budget = min(self.max_num_encoder_input_tokens, self.encoder_cache_size) - max_num_mm_items_encoder_budget = encoder_budget // \ - max_tokens_per_mm_item + max_num_mm_items_encoder_budget = encoder_budget // max_tokens_per_mm_item # Check how many items of this modality can be supported by # the decoder budget. @@ -2452,8 +2497,7 @@ def profile_run(self) -> None: # NOTE: We do not consider max_num_batched_tokens on purpose # because the multimodal embeddings can be generated in advance # and chunked prefilled. - max_num_mm_items_decoder_budget = self.max_num_reqs * \ - max_mm_items_per_req + max_num_mm_items_decoder_budget = self.max_num_reqs * max_mm_items_per_req max_num_mm_items = max( 1, @@ -2463,7 +2507,10 @@ def profile_run(self) -> None: logger.info( "Encoder cache will be initialized with a budget of %s tokens," " and profiled with %s %s items of the maximum feature size.", - encoder_budget, max_num_mm_items, dummy_data_modality) + encoder_budget, + max_num_mm_items, + dummy_data_modality, + ) # Create dummy batch of multimodal inputs. 
dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( @@ -2495,8 +2542,8 @@ def profile_run(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states = self._dummy_run( + self.max_num_tokens, is_profile=True) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -2514,7 +2561,9 @@ def capture_model(self) -> None: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " "set -O %s and ensure `use_cudagraph` was not manually set to " - "False", CompilationLevel.PIECEWISE) + "False", + CompilationLevel.PIECEWISE, + ) return compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -2548,7 +2597,8 @@ def freeze_gc(): compilation_cases = tqdm( list(compilation_cases), disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graph shapes") + desc="Capturing CUDA graph shapes", + ) for num_tokens in compilation_cases: # We skip EPLB here since we don't want to record dummy metrics for _ in range( @@ -2565,8 +2615,11 @@ def freeze_gc(): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. - logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) def _initialize_single_attn_backend( self, kv_cache_spec: KVCacheSpec, layer_names: list[str] @@ -2626,16 +2679,16 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. 
""" - assert len(self.attn_backends) == 0 and len( - self.attn_metadata_builders - ) == 0, "Attention backends are already initialized" + assert (len(self.attn_backends) == 0 + and len(self.attn_metadata_builders) + == 0), "Attention backends are already initialized" for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec attn_backend_i, attn_metadata_builder_i = ( self._initialize_single_attn_backend( - kv_cache_spec, kv_cache_group_spec.layer_names)) + kv_cache_spec, kv_cache_group_spec.layer_names))) self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) @@ -2657,17 +2710,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: "window attention is not supported for encoder-only models" attn_specs.append( - FullAttentionSpec(block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla)) + FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla, + )) else: raise ValueError("Expected only encoder-only layers") if len(attn_specs) > 0: - assert len(attn_specs) == len(attn_layers), \ - "All or none of the layers are expected to be encoder-only" + assert len(attn_specs) == len( + attn_layers + ), "All or none of the layers are expected to be encoder-only" attn_backend, attn_metadata_builder = ( self._initialize_single_attn_backend(attn_specs[0], @@ -2740,7 +2796,7 @@ def _allocate_kv_cache_tensors( Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. - """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: tensor = torch.zeros(kv_cache_tensor.size, @@ -2780,13 +2836,16 @@ def _reshape_kv_cache_tensors( for layer_name in kv_cache_group_spec.layer_names: raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel( + ) // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) dtype = kv_cache_spec.dtype try: kv_cache_stride_order = self.attn_backends[ @@ -2808,9 +2867,9 @@ def _reshape_kv_cache_tensors( kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = kv_cache_raw_tensors[ - layer_name].view(dtype).view(kv_cache_shape).permute( - *inv_order) + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name].view(dtype).view( + kv_cache_shape).permute(*inv_order)) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] @@ -2818,18 +2877,20 @@ def _reshape_kv_cache_tensors( num_element_per_page = (kv_cache_spec.page_size_bytes // get_dtype_size(dtype)) - if kv_cache_spec.mamba_type == "mamba1": - state_tensors = self._create_mamba1_state_tensors( - raw_tensor=raw_tensor, - dtype=dtype, - shapes=kv_cache_spec.shapes) - else: - state_tensors = self._create_mamba2_state_tensors( - raw_tensor=raw_tensor, - 
shapes=kv_cache_spec.shapes, - num_blocks=num_blocks, - num_element_per_page=num_element_per_page, - dtype=dtype) + state_tensors = [] + storage_offset = 0 + for shape in kv_cache_spec.shapes: + target_shape = (num_blocks, *shape) + stride = torch.empty(target_shape).stride() + target_stride = (num_element_per_page, *stride[1:]) + tensor = torch.as_strided( + raw_tensor.view(dtype), + size=target_shape, + stride=target_stride, + storage_offset=storage_offset, + ) + state_tensors.append(tensor) + storage_offset += stride[0] kv_caches[layer_name] = state_tensors else: @@ -2842,8 +2903,10 @@ def _reshape_kv_cache_tensors( return kv_caches def _verify_hybrid_attention_mamba_layout( - self, kv_cache_config: KVCacheConfig, - kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + ) -> None: """ Verify that the KV cache memory layout is compatible for models with both attention and mamba KV cache groups. @@ -2858,12 +2921,15 @@ def _verify_hybrid_attention_mamba_layout( kv_cache_spec = kv_cache_group_spec.kv_cache_spec for layer_name in kv_cache_group_spec.layer_names: raw_tensor = kv_cache_raw_tensors[layer_name] - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel( + ) // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) if kv_cache_shape[0] != num_blocks or kv_cache_shape[ 1] != 2: raise ValueError( @@ -2971,10 +3037,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=use_mla) + use_mla=use_mla, + ) assert not use_local_attention, ( "attention module can not be with ", - "both local attention and sliding window") + "both local attention and sliding window", + ) elif use_local_attention: kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, @@ -2982,16 +3050,20 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, attention_chunk_size=self.attention_chunk_size, - use_mla=use_mla) + use_mla=use_mla, + ) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + use_mla=use_mla, + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): # encoder-only attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: @@ -3010,8 +3082,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) + page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. 
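For intuition about the as_strided loop restored in _reshape_kv_cache_tensors above: each mamba state becomes a strided view into the same flat per-layer buffer, with the block stride pinned to the page size so every block's states stay inside that block's page. A rough standalone sketch with toy sizes (plain PyTorch; the real runner derives the page size from page_size_bytes and dtype, which is simplified away here):

import math
import torch

num_blocks = 4
shapes = [(3, 2304), (32, 64, 128)]               # conv state, temporal state
elems_per_page = sum(math.prod(s) for s in shapes)
raw = torch.zeros(num_blocks * elems_per_page)    # flat per-layer buffer

state_tensors, storage_offset = [], 0
for shape in shapes:
    target_shape = (num_blocks, *shape)
    stride = torch.empty(target_shape).stride()
    # block stride = page size, inner strides stay contiguous
    target_stride = (elems_per_page, *stride[1:])
    t = torch.as_strided(raw, size=target_shape, stride=target_stride,
                         storage_offset=storage_offset)
    storage_offset += stride[0]
    state_tensors.append(t)

conv_state, ssm_state = state_tensors
print(conv_state.shape, ssm_state.shape)
# torch.Size([4, 3, 2304]) torch.Size([4, 32, 64, 128])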
@@ -3021,13 +3092,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, block_size=max_model_len, page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type) + mamba_type=mamba_module.mamba_type, + ) return kv_cache_spec def _build_encoder_only_attn_metadata( - self, scheduler_output: "SchedulerOutput") -> \ - tuple[CommonAttentionMetadata, Any]: + self, scheduler_output: "SchedulerOutput" + ) -> tuple[CommonAttentionMetadata, Any]: """Prepare encoder attention metadata for encoder-only models. Args: @@ -3074,58 +3146,3 @@ def _build_encoder_only_attn_metadata( common_prefix_len=0, # No cascade for encoder common_attn_metadata=common_metadata, ) - - def _create_mamba1_state_tensors(self, raw_tensor: torch.Tensor, - dtype: torch.dtype, - shapes: tuple) -> list[torch.Tensor]: - storage_offset = 0 - conv_state_shape, temporal_state_shape = shapes - num_sequences = len(self.seq_lens) - - conv_target_shape = (num_sequences, conv_state_shape[1], - conv_state_shape[0]) - conv_stride = torch.empty(conv_target_shape).stride() - conv_state = torch.as_strided( - raw_tensor.view(dtype), - size=conv_target_shape, - stride=conv_stride, - storage_offset=storage_offset, - ).transpose(-1, -2) - - conv_elements = conv_target_shape[0] * conv_target_shape[ - 1] * conv_target_shape[2] - storage_offset += conv_elements - - temporal_target_shape = (num_sequences, temporal_state_shape[0], - temporal_state_shape[1]) - temporal_stride = torch.empty(temporal_target_shape).stride() - temporal_state = torch.as_strided( - raw_tensor.view(dtype), - size=temporal_target_shape, - stride=temporal_stride, - storage_offset=storage_offset, - ) - - return [conv_state, temporal_state] - - def _create_mamba2_state_tensors(self, raw_tensor: torch.Tensor, - shapes: tuple, num_blocks: int, - num_element_per_page: int, - dtype: torch.dtype) -> list[torch.Tensor]: - state_tensors = [] - storage_offset = 0 - - for shape in shapes: - target_shape = (num_blocks, *shape) - stride = torch.empty(target_shape).stride() - target_stride = (num_element_per_page, *stride[1:]) - tensor = torch.as_strided( - raw_tensor.view(dtype), - size=target_shape, - stride=target_stride, - storage_offset=storage_offset, - ) - storage_offset += stride[0] - state_tensors.append(tensor) - - return state_tensors From 0ec3208b545c49e39cf2ecbfa3618ae827f85c2f Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 14:55:30 +0300 Subject: [PATCH 30/44] feat: Updated selective scan fwd to work with strides Signed-off-by: asafg --- csrc/mamba/mamba_ssm/selective_scan.h | 3 +++ csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 12 +++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 563d2fe4ef65..13c6178941cf 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -45,6 +45,9 @@ struct SSMParamsBase { index_t out_d_stride; index_t out_z_batch_stride; index_t out_z_d_stride; + index_t ssm_states_batch_stride; + index_t ssm_states_dim_stride; + index_t ssm_states_dstate_stride; // Common data pointers. 
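A note on the three ssm_states stride fields added just above: the selective_scan_fwd.cu change that follows indexes the cached SSM state through explicit batch/dim/dstate strides instead of assuming a contiguous (batch, dim, dstate) layout, so strided or transposed cache views can be passed straight in. A small Python analogue of that address arithmetic, with made-up sizes and indices:

import torch

flat = torch.arange(2 * 4 * 8, dtype=torch.float32)
# cache laid out as (num_seqs, dstate, dim) in memory, viewed as (num_seqs, dim, dstate)
ssm = flat.view(2, 8, 4).transpose(1, 2)                 # shape (2, 4, 8), non-contiguous
batch_stride, dim_stride, dstate_stride = ssm.stride()   # (32, 1, 4)

def load(cache_index: int, d: int, s: int) -> float:
    # same offset math the kernel now performs with the stride params
    off = cache_index * batch_stride + d * dim_stride + s * dstate_stride
    return flat[off].item()

assert load(1, 2, 5) == ssm[1, 2, 5].item()
print(ssm.stride())   # dstate is no longer the fastest-moving axis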
void *__restrict__ A_ptr; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 5766fbab4e87..6485fbaa1d25 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -132,7 +132,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; - input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate; + input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + + cache_index * params.ssm_states_batch_stride + + dim_id * kNRows * params.ssm_states_dim_stride; float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -248,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } // Initialize running total - scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0); + scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0); SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( @@ -259,7 +261,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (threadIdx.x == 0) { smem_running_prefix[state_idx] = prefix_op.running_prefix; if (chunk == n_chunks - 1) { - ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); + ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y); } } #pragma unroll @@ -481,6 +483,10 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.out_batch_stride = out.stride(1); params.out_d_stride = out.stride(0); + params.ssm_states_batch_stride = ssm_states.stride(0); + params.ssm_states_dim_stride = ssm_states.stride(1); + params.ssm_states_dstate_stride = ssm_states.stride(2); + } else{ if (!is_variable_B) { From 611a771297da2ee7d1086452b719a9246e479694 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 15:07:20 +0300 Subject: [PATCH 31/44] fix: Conflicted changes in gpu_model_runner Signed-off-by: asafg --- vllm/v1/worker/gpu_model_runner.py | 443 ++++++++++++----------------- 1 file changed, 185 insertions(+), 258 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0430af33ebca..871ca97cb555 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -45,7 +45,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up, supports_dynamo) -from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, @@ -87,10 +86,8 @@ else: xgr = LazyLoader("xgr", globals(), "xgrammar") xgr_torch_compile = LazyLoader( - "xgr_torch_compile", - globals(), - "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile", - ) + "xgr_torch_compile", 
globals(), + "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile") logger = init_logger(__name__) @@ -114,7 +111,6 @@ def __init__( self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -282,8 +278,7 @@ def __init__( (3, self.max_num_tokens + 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory, - ) + pin_memory=self.pin_memory) self.mrope_positions_np = self.mrope_positions_cpu.numpy() # Only relevant for models using ALiBi (e.g, MPT) @@ -292,45 +287,35 @@ def __init__( self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, - device=self.device, - ) + device=self.device) # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange( - max(self.max_num_reqs + 1, self.max_model_len, - self.max_num_tokens), - dtype=np.int64, - ) + self.arange_np = np.arange(max(self.max_num_reqs + 1, + self.max_model_len, + self.max_num_tokens), + dtype=np.int64) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. Thus, we should # not make any assumptions about the values in these tensors. - self.input_ids_cpu = torch.zeros( - self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.positions_cpu = torch.zeros( - self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) + self.input_ids_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.positions_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.query_start_loc_cpu = torch.zeros( - self.max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros( - self.max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) + self.seq_lens_cpu = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() # Layer pairings for cross-layer KV sharing. @@ -373,7 +358,8 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # Note: used for model runner override. 
def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties""" + """Initialize attributes from torch.cuda.get_device_properties + """ self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -434,16 +420,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if (sampling_params and sampling_params.sampling_type - == SamplingType.RANDOM_SEED): + if sampling_params and \ + sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: generator = None if pooling_params: - assert (task := pooling_params.task - ) is not None, "You did not set `task` in the API" + assert (task := pooling_params.task) is not None, ( + "You did not set `task` in the API") model = cast(VllmModelForPooling, self.model) to_update = model.pooler.get_pooling_updates(task) @@ -488,18 +474,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: hf_config = self.model_config.hf_config - ( - self.requests[req_id].mrope_positions, - self.requests[req_id].mrope_position_delta, - ) = MRotaryEmbedding.get_input_positions_tensor( - self.requests[req_id].prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + self.requests[req_id].mrope_positions, \ + self.requests[req_id].mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + self.requests[req_id].prompt_token_ids, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) req_ids_to_add.append(req_id) @@ -551,8 +536,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[ - req_index] = num_computed_tokens + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu @@ -569,8 +554,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, ()) + spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] @@ -648,14 +633,9 @@ def _get_cumsum_and_arange( def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[ - dict[str, Any], - bool, - torch.Tensor, - Optional[SpecDecodeMetadata], - np.ndarray, - Optional[CommonAttentionMetadata], - ]: + ) -> tuple[dict[str, + Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], + np.ndarray, Optional[CommonAttentionMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -690,11 +670,9 @@ def _prepare_inputs( # Get positions. 
positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np, - ) + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -711,12 +689,10 @@ def _prepare_inputs( # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select( - self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens], - ) + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) @@ -738,8 +714,7 @@ def _prepare_inputs( # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], - non_blocking=True, - ) + non_blocking=True) else: # Common case (1D positions) self.positions[:total_num_scheduled_tokens].copy_( @@ -816,8 +791,9 @@ def _prepare_inputs( # Prepare encoder attention metadata separately # (encoder layers are not in KV cache groups) if self.is_encoder_only_model: - common_attn_metadata, encoder_attn_metadata = ( - self._build_encoder_only_attn_metadata(scheduler_output)) + common_attn_metadata, encoder_attn_metadata = \ + self._build_encoder_only_attn_metadata( + scheduler_output) # Add encoder attention metadata for all encoder layers attention_layers = get_layers_from_vllm_config( @@ -854,16 +830,15 @@ def _prepare_inputs( causal=True, ) - if self.speculative_config and spec_decode_common_attn_metadata is None: + if self.speculative_config and \ + spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata if isinstance(kv_cache_group_spec.kv_cache_spec, ChunkedLocalAttentionSpec): common_attn_metadata = make_local_attention_virtual_batches( kv_cache_group_spec.kv_cache_spec.attention_chunk_size, - common_attn_metadata, - self.cache_config.block_size, - ) + common_attn_metadata, self.cache_config.block_size) # Prepare for cascade attention if enabled & beneficial. common_prefix_len = 0 @@ -877,10 +852,10 @@ def _prepare_inputs( builder, ) - attn_metadata_i = builder.build( + attn_metadata_i = (builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - ) + )) fast_prefill_metadata = attn_metadata_i if (self.cache_config.kv_sharing_fast_prefill @@ -939,14 +914,9 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return ( - attn_metadata, - attention_cuda_graphs, - logits_indices, - spec_decode_metadata, - num_scheduled_tokens, - spec_decode_common_attn_metadata, - ) + return (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata, num_scheduled_tokens, + spec_decode_common_attn_metadata) def _compute_cascade_attn_prefix_len( self, @@ -1023,13 +993,13 @@ def _compute_cascade_attn_prefix_len( # common_prefix_len should be a multiple of the block size. 
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size) - use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( - isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None) - use_local_attention = isinstance( - kv_cache_spec, ChunkedLocalAttentionSpec) or ( - isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) + use_local_attention = ( + isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) + or (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None)) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1049,10 +1019,10 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ - index] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = len(req.prompt_token_ids) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: @@ -1073,10 +1043,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions_cpu[:, dst_start: - dst_end] = req.mrope_positions[:, - src_start: - src_end] + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] mrope_pos_ptr += prompt_part_len @@ -1260,8 +1228,7 @@ def _gather_mm_embeddings( start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens, - ) + num_encoder_tokens) assert start_idx < end_idx assert req_id in self.encoder_cache assert i in self.encoder_cache[req_id] @@ -1355,8 +1322,8 @@ def apply_grammar_bitmask( num_spec_tokens = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + - i] = grammar_bitmask[cumulative_index + i] + sorted_bitmask[logit_index + i] = \ + grammar_bitmask[cumulative_index + i] out_indices.append(logit_index + i) cumulative_index += 1 + num_spec_tokens grammar_bitmask = sorted_bitmask @@ -1375,21 +1342,20 @@ def apply_grammar_bitmask( ) def sync_and_slice_intermediate_tensors( - self, - num_tokens: int, - intermediate_tensors: IntermediateTensors, - sync_self: bool, - ) -> IntermediateTensors: + self, num_tokens: int, intermediate_tensors: IntermediateTensors, + sync_self: bool) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size - enabled_sp = self.compilation_config.pass_config.enable_sequence_parallelism + enabled_sp = self.compilation_config.pass_config. 
\ + enable_sequence_parallelism if enabled_sp: # When sequence parallelism is enabled, we always pad num_tokens # to be a multiple of tensor_parallel_size (tp) earlier assert num_tokens % tp == 0 - is_residual_scattered = tp > 1 and enabled_sp and num_tokens % tp == 0 + is_residual_scattered = tp > 1 and enabled_sp \ + and num_tokens % tp == 0 # When sequence parallelism is enabled, the "residual" tensor is sharded # across tensor parallel ranks, so each rank only needs its own slice. @@ -1397,13 +1363,15 @@ def sync_and_slice_intermediate_tensors( assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_residual_scattered - copy_len = num_tokens // tp if is_scattered else num_tokens + copy_len = num_tokens // tp if is_scattered else \ + num_tokens self.intermediate_tensors[k][:copy_len].copy_( v[:copy_len], non_blocking=True) return IntermediateTensors({ - k: (v[:num_tokens // tp] if k == "residual" - and is_residual_scattered else v[:num_tokens]) + k: + v[:num_tokens // tp] + if k == "residual" and is_residual_scattered else v[:num_tokens] for k, v in self.intermediate_tensors.items() }) @@ -1509,14 +1477,11 @@ def execute_model( self.vllm_config) # Prepare the decoder inputs. - ( - attn_metadata, - attention_cuda_graphs, - logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np, - spec_decode_common_attn_metadata, - ) = self._prepare_inputs(scheduler_output) + (attn_metadata, attention_cuda_graphs, logits_indices, + spec_decode_metadata, num_scheduled_tokens_np, + spec_decode_common_attn_metadata) = ( + self._prepare_inputs(scheduler_output)) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1529,8 +1494,8 @@ def execute_model( # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if (self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1): + if self.compilation_config.pass_config. \ + enable_sequence_parallelism and tp_size > 1: num_input_tokens = round_up(num_scheduled_tokens, tp_size) else: num_input_tokens = num_scheduled_tokens @@ -1621,9 +1586,9 @@ def execute_model( # to make sure we are synced across pp ranks # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 - broadcast_pp_output = ( - self.parallel_config.distributed_executor_backend - == "external_launcher" and len(get_pp_group().ranks) > 0) + broadcast_pp_output = \ + self.parallel_config.distributed_executor_backend \ + == "external_launcher" and len(get_pp_group().ranks) > 0 if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. 
assert isinstance(hidden_states, IntermediateTensors) @@ -1641,9 +1606,9 @@ def execute_model( sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) if broadcast_pp_output: - model_output_broadcast_data = ({ + model_output_broadcast_data = { "logits": logits.contiguous(), - } if logits is not None else {}) + } if logits is not None else {} model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) assert model_output_broadcast_data is not None @@ -1867,7 +1832,8 @@ def propose_draft_token_ids( ] num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, dtype=torch.int32) - common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( common_attn_metadata, num_rejected_tokens_cpu) target_token_ids = self.input_ids[token_indices] @@ -1932,9 +1898,9 @@ def propose_ngram_draft_token_ids( def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, ( - f"Config `{config_name}` not supported. " - f"Allowed configs: {allowed_config_names}") + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. " \ + f"Allowed configs: {allowed_config_names}" config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -1947,7 +1913,6 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") @@ -1956,15 +1921,15 @@ def load_model(self, eep_scale_up: bool = False) -> None: group_src=0) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = EplbState.recv_state( - ) + global_expert_load, old_global_expert_indices = ( + EplbState.recv_state()) num_logical_experts = global_expert_load.shape[1] self.parallel_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) assert old_global_expert_indices.shape[ 1] % num_local_physical_experts == 0 - old_ep_size = (old_global_expert_indices.shape[1] // - num_local_physical_experts) + old_ep_size = old_global_expert_indices.shape[ + 1] // num_local_physical_experts rank_mapping = { old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) @@ -1981,13 +1946,11 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config) if self.lora_config: - self.model = self.load_lora_model( - self.model, - self.model_config, - self.scheduler_config, - self.lora_config, - self.device, - ) + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) @@ -1996,11 +1959,9 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model.get_eagle3_aux_hidden_state_layers()) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info( - "Model loading took 
%.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load, - ) + logger.info("Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load) prepare_communication_buffer_for_model(self.model) if is_mixture_of_experts( @@ -2028,8 +1989,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: backend=backend) def reload_weights(self) -> None: - assert (getattr(self, "model", None) - is not None), "Cannot reload weights before model is loaded." + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") model_loader.load_weights(self.model, model_config=self.model_config) @@ -2180,8 +2141,7 @@ def rand_input_ids() -> torch.Tensor: self.input_ids, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype, - ) + dtype=input_ids.dtype) logger.debug("Randomizing dummy data for DP Rank") input_ids.copy_(rand_input_ids()[:input_ids.size(0)], @@ -2243,8 +2203,7 @@ def _dummy_run( kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. block_table[kv_cache_group_id].slot_mapping[:num_tokens], - causal=True, - ) + causal=True) attn_metadata_i = self.attn_metadata_builders[ kv_cache_group_id].build_for_cudagraph_capture( @@ -2278,21 +2237,16 @@ def _dummy_run( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device, - )) + device=self.device)) intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - with ( - self.maybe_randomize_inputs(input_ids), - set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - ), - ): + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): outputs = model( input_ids=input_ids, positions=positions, @@ -2364,7 +2318,7 @@ def _dummy_sampler_run( sampler_output = self.sampler(logits=logits, sampling_metadata=dummy_metadata) except RuntimeError as e: - if "out of memory" in str(e): + if 'out of memory' in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " @@ -2442,7 +2396,7 @@ def _dummy_pooler_run_task( return model.pooler(hidden_states=hidden_states_list, pooling_metadata=dummy_metadata) except RuntimeError as e: - if "out of memory" in str(e): + if 'out of memory' in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " @@ -2476,9 +2430,8 @@ def profile_run(self) -> None: # NOTE: Currently model is profiled with a single non-text # modality with the max possible input tokens even when # it supports multiple. 
- max_tokens_by_modality_dict = ( - self.mm_registry.get_max_tokens_per_item_by_nonzero_modality( - self.model_config)) + max_tokens_by_modality_dict = self.mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(self.model_config) dummy_data_modality, max_tokens_per_mm_item = max( max_tokens_by_modality_dict.items(), key=lambda item: item[1]) @@ -2487,7 +2440,8 @@ def profile_run(self) -> None: encoder_budget = min(self.max_num_encoder_input_tokens, self.encoder_cache_size) - max_num_mm_items_encoder_budget = encoder_budget // max_tokens_per_mm_item + max_num_mm_items_encoder_budget = encoder_budget // \ + max_tokens_per_mm_item # Check how many items of this modality can be supported by # the decoder budget. @@ -2497,7 +2451,8 @@ def profile_run(self) -> None: # NOTE: We do not consider max_num_batched_tokens on purpose # because the multimodal embeddings can be generated in advance # and chunked prefilled. - max_num_mm_items_decoder_budget = self.max_num_reqs * max_mm_items_per_req + max_num_mm_items_decoder_budget = self.max_num_reqs * \ + max_mm_items_per_req max_num_mm_items = max( 1, @@ -2507,10 +2462,7 @@ def profile_run(self) -> None: logger.info( "Encoder cache will be initialized with a budget of %s tokens," " and profiled with %s %s items of the maximum feature size.", - encoder_budget, - max_num_mm_items, - dummy_data_modality, - ) + encoder_budget, max_num_mm_items, dummy_data_modality) # Create dummy batch of multimodal inputs. dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( @@ -2542,8 +2494,8 @@ def profile_run(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states = self._dummy_run( - self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states \ + = self._dummy_run(self.max_num_tokens, is_profile=True) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -2561,9 +2513,7 @@ def capture_model(self) -> None: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " "set -O %s and ensure `use_cudagraph` was not manually set to " - "False", - CompilationLevel.PIECEWISE, - ) + "False", CompilationLevel.PIECEWISE) return compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -2597,8 +2547,7 @@ def freeze_gc(): compilation_cases = tqdm( list(compilation_cases), disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graph shapes", - ) + desc="Capturing CUDA graph shapes") for num_tokens in compilation_cases: # We skip EPLB here since we don't want to record dummy metrics for _ in range( @@ -2615,11 +2564,8 @@ def freeze_gc(): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. - logger.info( - "Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, - cuda_graph_size / (1 << 30), - ) + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) def _initialize_single_attn_backend( self, kv_cache_spec: KVCacheSpec, layer_names: list[str] @@ -2679,16 +2625,16 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. 
""" - assert (len(self.attn_backends) == 0 - and len(self.attn_metadata_builders) - == 0), "Attention backends are already initialized" + assert len(self.attn_backends) == 0 and len( + self.attn_metadata_builders + ) == 0, "Attention backends are already initialized" for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec attn_backend_i, attn_metadata_builder_i = ( self._initialize_single_attn_backend( - kv_cache_spec, kv_cache_group_spec.layer_names))) + kv_cache_spec, kv_cache_group_spec.layer_names)) self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) @@ -2710,20 +2656,17 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: "window attention is not supported for encoder-only models" attn_specs.append( - FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla, - )) + FullAttentionSpec(block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla)) else: raise ValueError("Expected only encoder-only layers") if len(attn_specs) > 0: - assert len(attn_specs) == len( - attn_layers - ), "All or none of the layers are expected to be encoder-only" + assert len(attn_specs) == len(attn_layers), \ + "All or none of the layers are expected to be encoder-only" attn_backend, attn_metadata_builder = ( self._initialize_single_attn_backend(attn_specs[0], @@ -2732,8 +2675,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_metadata_builders.append(attn_metadata_builder) self.is_encoder_only_model = True - def calculate_reorder_batch_threshold( - self) -> None: + def calculate_reorder_batch_threshold(self) -> None: """ Check that if any backends reorder batches; that the reordering is compatible (e.g., decode threshold is the same) @@ -2796,7 +2738,7 @@ def _allocate_kv_cache_tensors( Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. 
- """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: tensor = torch.zeros(kv_cache_tensor.size, @@ -2836,16 +2778,13 @@ def _reshape_kv_cache_tensors( for layer_name in kv_cache_group_spec.layer_names: raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = raw_tensor.numel( - ) // kv_cache_spec.page_size_bytes + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - ) + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype try: kv_cache_stride_order = self.attn_backends[ @@ -2867,16 +2806,15 @@ def _reshape_kv_cache_tensors( kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name].view(dtype).view( - kv_cache_shape).permute(*inv_order)) + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape).permute( + *inv_order) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] dtype = kv_cache_spec.dtype num_element_per_page = (kv_cache_spec.page_size_bytes // get_dtype_size(dtype)) - state_tensors = [] storage_offset = 0 for shape in kv_cache_spec.shapes: @@ -2903,10 +2841,8 @@ def _reshape_kv_cache_tensors( return kv_caches def _verify_hybrid_attention_mamba_layout( - self, - kv_cache_config: KVCacheConfig, - kv_cache_raw_tensors: dict[str, torch.Tensor], - ) -> None: + self, kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None: """ Verify that the KV cache memory layout is compatible for models with both attention and mamba KV cache groups. 
@@ -2921,15 +2857,12 @@ def _verify_hybrid_attention_mamba_layout( kv_cache_spec = kv_cache_group_spec.kv_cache_spec for layer_name in kv_cache_group_spec.layer_names: raw_tensor = kv_cache_raw_tensors[layer_name] - num_blocks = raw_tensor.numel( - ) // kv_cache_spec.page_size_bytes + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - ) + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) if kv_cache_shape[0] != num_blocks or kv_cache_shape[ 1] != 2: raise ValueError( @@ -3037,12 +2970,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=use_mla, - ) + use_mla=use_mla) assert not use_local_attention, ( "attention module can not be with ", - "both local attention and sliding window", - ) + "both local attention and sliding window") elif use_local_attention: kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, @@ -3050,20 +2981,16 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, attention_chunk_size=self.attention_chunk_size, - use_mla=use_mla, - ) + use_mla=use_mla) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla, - ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): + use_mla=use_mla) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: @@ -3082,7 +3009,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. @@ -3092,14 +3020,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, block_size=max_model_len, page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type, - ) + mamba_type=mamba_module.mamba_type) return kv_cache_spec def _build_encoder_only_attn_metadata( - self, scheduler_output: "SchedulerOutput" - ) -> tuple[CommonAttentionMetadata, Any]: + self, scheduler_output: "SchedulerOutput") -> \ + tuple[CommonAttentionMetadata, Any]: """Prepare encoder attention metadata for encoder-only models. 
Args: From 4a573959b3699c527cc329e5400c95c2e672d4cc Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 15:14:49 +0300 Subject: [PATCH 32/44] fix: Lint Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_utils.py | 10 ++++++---- vllm/model_executor/models/jamba.py | 1 + vllm/model_executor/models/mamba.py | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 4cb20debb5e4..42c815b08f04 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm import envs from vllm.distributed import divide @@ -13,6 +12,7 @@ def mamba1_state_shape( intermediate_size: int, state_size: int, conv_kernel: int, + use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int]]: conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) @@ -20,9 +20,11 @@ def mamba1_state_shape( temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) - if envs.VLLM_USE_V1: - return (conv_state_shape[1], - conv_state_shape[0]), temporal_state_shape + # In V0, the conv_state shape was swapped during allocation in + # MambaCacheManager, but in V1 it needs to be determined here at the + # calculation level + if use_v1: + conv_state_shape = conv_state_shape[1], conv_state_shape[0] return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ca936c96a57d..eaf498dab359 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -554,6 +554,7 @@ def get_mamba_state_shape_from_config( intermediate_size=hf_config.mamba_expand * hidden_size, state_size=hf_config.mamba_d_state, conv_kernel=hf_config.mamba_d_conv, + use_v1=envs.VLLM_USE_V1, ) def compute_logits( diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 8a57db1cffde..80b63e15377a 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -266,7 +266,8 @@ def get_mamba_state_shape_from_config( tp_world_size=parallel_config.tensor_parallel_size, intermediate_size=hf_config.intermediate_size, state_size=hf_config.state_size, - conv_kernel=hf_config.conv_kernel) + conv_kernel=hf_config.conv_kernel, + use_v1=envs.VLLM_USE_V1) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.mamba_cache.copy_inputs_before_cuda_graphs( From a0958b4bea3892b471afc04e8c2e7077d180be63 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 15:19:39 +0300 Subject: [PATCH 33/44] fix: Lint in tests Signed-off-by: asafg --- tests/models/language/generation/test_hybrid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index f62e1b7128d3..f7834d98137e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -54,6 +54,7 @@ V1_SUPPORTED_MODELS = [ "state-spaces/mamba-130m-hf", + "ai21labs/Jamba-tiny-dev", "mistralai/Mamba-Codestral-7B-v0.1", "ibm-ai-platform/Bamba-9B-v1", "Zyphra/Zamba2-1.2B-instruct", @@ -102,8 +103,8 @@ def test_models( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: + # required due to 
reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - if model in SSM_MODELS: # Set to True until support in CUDA Graphs enforce_eager = True @@ -377,7 +378,6 @@ def test_distributed_correctness( max_tokens: int, num_logprobs: int, ) -> None: - # Set enforce_eager=True until support in CUDA Graphs with vllm_runner(model, tensor_parallel_size=1, max_num_seqs=2) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( From 9b8fd69ad0a7155e99def1fbb42a789e957837a5 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 15:48:24 +0300 Subject: [PATCH 34/44] refactor: Added v1 jamba support Signed-off-by: asafg --- vllm/model_executor/models/jamba.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index eaf498dab359..8a9efd4d7247 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -35,8 +35,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsV0Only) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -446,9 +445,8 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -# TODO: Remove SupportsV0Only once v1 is fully supported class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsV0Only): + IsHybrid): hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ ".self_attn.": ".", ".A_log": ".A" From adcb205af2475235b59659fc8699fdaa721ef4c6 Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 19:16:53 +0300 Subject: [PATCH 35/44] fix: Lint Signed-off-by: asafg --- .../layers/mamba/mamba_mixer.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 8b5ae1fb8a63..2ce52fc69676 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -274,19 +274,18 @@ def forward_cuda(self, query_start_loc=query_start_loc) else: scan_outputs = torch.empty_like(hidden_states.transpose(0, 1)) - selective_state_update( - ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=state_indices_tensor, - out=scan_outputs) + selective_state_update(ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor, + out=scan_outputs) scan_outputs = scan_outputs.transpose(0, 1) # 4. 
Final linear projection From 5db14c55c06bb8e06dc56a4f6b4ecddb579eeffa Mon Sep 17 00:00:00 2001 From: asafg Date: Sun, 3 Aug 2025 19:29:52 +0300 Subject: [PATCH 36/44] fix: Removed unnecessary props from mamba1_attn Signed-off-by: asafg --- .../layers/mamba/mamba_mixer.py | 5 ++-- vllm/v1/attention/backends/mamba1_attn.py | 23 +++---------------- vllm/v1/worker/gpu_model_runner.py | 20 ++++++++-------- 3 files changed, 15 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 2ce52fc69676..02132f19c26c 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -167,9 +167,8 @@ def forward_cuda(self, if envs.VLLM_USE_V1: if attn_metadata is not None: - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.prefix] - + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] mamba1_metadata = attn_metadata assert isinstance(mamba1_metadata, Mamba1AttentionMetadata) query_start_loc = mamba1_metadata.query_start_loc diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 7d482b37a0ab..b14741265278 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -10,7 +10,7 @@ from vllm.config import VllmConfig from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + reorder_batch_to_split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec if TYPE_CHECKING: @@ -27,12 +27,7 @@ def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: @dataclass class Mamba1AttentionMetadata: - num_prefills: int - num_prefill_tokens: int - num_decodes: int - num_decode_tokens: int query_start_loc: torch.Tensor - seq_lens: torch.Tensor context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor has_initial_states: torch.Tensor @@ -46,11 +41,13 @@ def __init__( kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device: torch.device, + layer_names: list[str], ): assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec self.device = device self.vllm_config = vllm_config + self.layer_names = layer_names def reorder_batch( self, @@ -71,28 +68,14 @@ def build( ) -> Mamba1AttentionMetadata: query_start_loc = common_attn_metadata.query_start_loc - seq_lens = (common_attn_metadata.seq_lens.to( - query_start_loc.device).to(torch.int32)) - - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=1, - ) - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( query_start_loc.device) has_initial_states = (context_lens_tensor > 0) return Mamba1AttentionMetadata( - num_prefills=num_prefills, - num_prefill_tokens=num_prefill_tokens, - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, - seq_lens=seq_lens, has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 871ca97cb555..2cea87c3efc2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1425,10 
+1425,10 @@ def _pool(
         num_scheduled_tokens_np: np.ndarray,
         kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
-        assert self.input_batch.num_reqs == len(
-            self.input_batch.pooling_params), (
-                "Either all or none of the requests in"
-                " a batch must be pooling request")
+        assert self.input_batch.num_reqs ==\
+            len(self.input_batch.pooling_params), \
+            "Either all or none of the requests in" \
+            " a batch must be pooling request"
 
         extracted_hidden_states = list(
             torch.split(hidden_states[:num_scheduled_tokens],
@@ -1676,8 +1676,8 @@ def execute_model(
             # NOTE: GPU -> CPU Sync happens here.
             # Move as many CPU operations as possible before this sync point.
             logprobs_tensors = sampler_output.logprobs_tensors
-            logprobs_lists = (logprobs_tensors.tolists()
-                              if logprobs_tensors is not None else None)
+            logprobs_lists = logprobs_tensors.tolists() \
+                if logprobs_tensors is not None else None
 
             # Compute prompt logprobs if needed.
             prompt_logprobs_dict = self._get_prompt_logprobs_dict(
@@ -2597,10 +2597,10 @@ def _initialize_single_attn_backend(
                 f"Unknown KV cache spec type: {type(kv_cache_spec)}")
 
         attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-            kv_cache_spec,
-            layer_names,
-            self.vllm_config,
-            self.device,
+            kv_cache_spec=kv_cache_spec,
+            layer_names=layer_names,
+            vllm_config=self.vllm_config,
+            device=self.device,
         )
 
         if self.full_cuda_graph:

From 24a1aa6f374b87142018b08893e4f12809efba17 Mon Sep 17 00:00:00 2001
From: asafg
Date: Sun, 3 Aug 2025 20:36:09 +0300
Subject: [PATCH 37/44] test: Removed mamba1 from unsupported models in
 test_oracle

Signed-off-by: asafg
---
 tests/v1/test_oracle.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py
index b68ed298a189..a756c89b520f 100644
--- a/tests/v1/test_oracle.py
+++ b/tests/v1/test_oracle.py
@@ -12,7 +12,6 @@
 UNSUPPORTED_MODELS_V1 = [
     "openai/whisper-large-v3",  # transcription
     "facebook/bart-large-cnn",  # encoder decoder
-    "state-spaces/mamba-130m-hf",  # mamba1
 ]
 
 MODEL = "meta-llama/Llama-3.2-1B-Instruct"

From 23d09f893935a70fee9260d4d4e97fbb5929198c Mon Sep 17 00:00:00 2001
From: asafg
Date: Mon, 4 Aug 2025 10:58:00 +0300
Subject: [PATCH 38/44] fix: Added stride to non-varlen in kernel

Signed-off-by: asafg
---
 csrc/mamba/mamba_ssm/selective_scan_fwd.cu      | 6 +++++-
 vllm/model_executor/layers/mamba/mamba_mixer.py | 5 +++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index 6485fbaa1d25..c4ddbc142791 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -135,7 +135,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
   input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
       cache_index * params.ssm_states_batch_stride +
       dim_id * kNRows * params.ssm_states_dim_stride;
-  
+
   float D_val[kNRows] = {0};
   if (params.D_ptr != nullptr) {
     #pragma unroll
@@ -515,6 +515,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
   }
   params.out_batch_stride = out.stride(0);
   params.out_d_stride = out.stride(1);
+
+  params.ssm_states_batch_stride = ssm_states.stride(0);
+  params.ssm_states_dim_stride = ssm_states.stride(1);
+  params.ssm_states_dstate_stride = ssm_states.stride(2);
 }
 
 
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 02132f19c26c..e6a28ad872e9 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ 
b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -154,8 +154,9 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.prefix = prefix - def forward_native(self, hidden_states: torch.Tensor, - conv_state: torch.Tensor, ssm_state: torch.Tensor): + def forward_native(self, + hidden_states: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): pass def forward_cuda(self, From 1cc34d128a94a8fff968e6fef8778c98703973d3 Mon Sep 17 00:00:00 2001 From: asafg Date: Mon, 4 Aug 2025 11:12:35 +0300 Subject: [PATCH 39/44] docs: Updated docs to show mamba1 is supported in v1 Signed-off-by: asafg --- docs/models/supported_models.md | 4 ++-- docs/usage/v1_guide.md | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index c058c20f1ed7..49ca2a3b8c74 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -367,9 +367,9 @@ th { | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | -| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | +| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 38399c6633bd..d339401aea33 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Decoder-only Models** | 🚀 Optimized | | **Encoder-Decoder Models** | 🟠 Delayed | | **Embedding Models** | 🟢 Functional | -| **Mamba Models** | 🟢 (Mamba-2), 🟡 (Mamba-1) | +| **Mamba Models** | 🟢 (Mamba-2), 🟢 (Mamba-1) | | **Multimodal Models** | 🟢 Functional | vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. @@ -104,13 +104,11 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models -Models using selective state-space mechanisms instead of standard transformer attention are partially supported. -Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers -(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. 
Please note that these models currently require -disabling prefix caching in V1. +Models using selective state-space mechanisms instead of standard transformer attention are supported. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. -Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that +Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that these models currently require disabling prefix caching and using the FlashInfer attention backend in V1. #### Encoder-Decoder Models From 4a7e7f10299150407300eb6452cc3c456fdc84df Mon Sep 17 00:00:00 2001 From: asafg Date: Mon, 4 Aug 2025 11:13:28 +0300 Subject: [PATCH 40/44] fix: Lint Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index e6a28ad872e9..d39b862ad321 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -155,8 +155,8 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.prefix = prefix def forward_native(self, - hidden_states: torch.Tensor, - mamba_cache_params: Optional[MambaCacheParams] = None): + hidden_states: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): pass def forward_cuda(self, From f64191b0943d18f1d2fd097d1dcb3a9b6eab3dc7 Mon Sep 17 00:00:00 2001 From: asafg Date: Mon, 4 Aug 2025 13:57:47 +0300 Subject: [PATCH 41/44] fix: Added call to forward_cuda Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index d39b862ad321..6f4177f6c849 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -157,7 +157,7 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): def forward_native(self, hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): - pass + return self.forward_cuda(hidden_states, mamba_cache_params) def forward_cuda(self, hidden_states: torch.Tensor, From 8868aad94e848b2a2b26d48ed65763708553e02d Mon Sep 17 00:00:00 2001 From: asafg Date: Mon, 4 Aug 2025 14:08:18 +0300 Subject: [PATCH 42/44] test: Removed enforce_eager Signed-off-by: asafg --- tests/models/language/generation/test_hybrid.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index f7834d98137e..c0ae783f8534 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -99,19 +99,15 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: - enforce_eager = False with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: 
# required due to reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - # Set to True until support in CUDA Graphs - enforce_eager = True with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - enable_prefix_caching=False, - enforce_eager=enforce_eager) as vllm_model: + enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: From e6f77814d2e4cf4300a183815aca4868645388a1 Mon Sep 17 00:00:00 2001 From: asafg Date: Mon, 4 Aug 2025 18:46:00 +0300 Subject: [PATCH 43/44] fix: Moved call to forward Signed-off-by: asafg --- vllm/model_executor/layers/mamba/mamba_mixer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 6f4177f6c849..17b7f84a933f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -154,10 +154,18 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.prefix = prefix + def forward(self, + hidden_states: torch.Tensor, + mamba_cache_params: Optional[MambaCacheParams] = None): + if not envs.VLLM_USE_V1: + return CustomOp.forward(self, hidden_states, mamba_cache_params) + else: + return self.forward_cuda(hidden_states, mamba_cache_params) + def forward_native(self, hidden_states: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None): - return self.forward_cuda(hidden_states, mamba_cache_params) + pass def forward_cuda(self, hidden_states: torch.Tensor, From 9ab94d41989663d1c3a7f4be9cc8ad5bf57a1663 Mon Sep 17 00:00:00 2001 From: asafg Date: Wed, 6 Aug 2025 09:26:53 +0300 Subject: [PATCH 44/44] fix: CR comments Signed-off-by: asafg --- docs/usage/v1_guide.md | 2 +- .../models/language/generation/test_hybrid.py | 1 - vllm/v1/attention/backends/mamba1_attn.py | 24 ++++--------------- vllm/v1/worker/gpu_model_runner.py | 8 +++---- 4 files changed, 10 insertions(+), 25 deletions(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index d339401aea33..d30144e8a825 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -105,7 +105,7 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`. Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). 
Please note that diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index c0ae783f8534..67ba2f25593d 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -104,7 +104,6 @@ def test_models( if model in HYBRID_MODELS: # required due to reorder_batch behaviour m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") - with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enable_prefix_caching=False) as vllm_model: diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index b14741265278..f0e4636fdb52 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,21 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import ClassVar import torch from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - class Mamba1AttentionBackend(AttentionBackend): @@ -36,6 +31,8 @@ class Mamba1AttentionMetadata: class Mamba1AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba1AttentionMetadata]): + reorder_batch_threshold: ClassVar[int] = 1 + def __init__( self, kv_cache_spec: AttentionSpec, @@ -49,17 +46,6 @@ def __init__( self.vllm_config = vllm_config self.layer_names = layer_names - def reorder_batch( - self, - input_batch: "InputBatch", - scheduler_output: "SchedulerOutput", - ) -> bool: - return reorder_batch_to_split_decodes_and_prefills( - input_batch, - scheduler_output, - decode_threshold=1, - ) - def build( self, common_prefix_len: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2cea87c3efc2..041687ae28b2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2597,10 +2597,10 @@ def _initialize_single_attn_backend( f"Unknown KV cache spec type: {type(kv_cache_spec)}") attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - kv_cache_spec=kv_cache_spec, - layer_names=layer_names, - vllm_config=self.vllm_config, - device=self.device, + kv_cache_spec, + layer_names, + self.vllm_config, + self.device, ) if self.full_cuda_graph:
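
Editor's note, not part of the patch series: the docs and tests above state that Mamba-1 models run on V1 only with prefix caching disabled and eager execution enforced. A minimal usage sketch of that path is shown below, assuming a vLLM build that includes these patches; the model name is the one used in `test_hybrid.py`, and the prompt and sampling settings are illustrative only.

```python
import os

# Opt in to the V1 engine, as the tests above do via monkeypatch.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

# Mamba-1 constraints from the updated v1_guide.md:
# prefix caching is not yet supported, and CUDA graphs are not yet
# supported for Mamba-1, so eager mode is enforced.
llm = LLM(
    model="state-spaces/mamba-130m-hf",
    enable_prefix_caching=False,
    enforce_eager=True,
)

outputs = llm.generate(
    ["The state-space model architecture"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```

Greedy sampling is used here only to make the run deterministic, mirroring the `generate_greedy_logprobs` calls in the hybrid-model tests.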