1717from vllm .model_executor .custom_op import CustomOp
1818from vllm .model_executor .layers .linear import (ColumnParallelLinear ,
1919 RowParallelLinear )
20+ from vllm .model_executor .layers .mamba .abstract import MambaBase
2021from vllm .model_executor .layers .mamba .mamba2_metadata import (Mamba2Metadata ,
2122 update_metadata )
2223from vllm .model_executor .layers .mamba .ops .causal_conv1d import (
@@ -219,7 +220,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
219220
220221# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
221222@CustomOp .register ("mamba_mixer2" )
222- class MambaMixer2 (CustomOp ):
223+ class MambaMixer2 (MambaBase , CustomOp ):
223224 """
224225 Compute ∆, A, B, C, and D the state space parameters and compute
225226 the `contextualized_states`. A, D are input independent
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
231232 """
232233
233234 def __init__ (
234- self ,
235- hidden_size : int ,
236- ssm_state_size : int ,
237- conv_kernel_size : int ,
238- intermediate_size : int ,
239- use_conv_bias : bool ,
240- use_bias : bool ,
241- n_groups : int = 1 ,
242- num_heads : int = 128 ,
243- head_dim : int = 64 ,
244- rms_norm_eps : float = 1e-5 ,
245- activation : str = "silu" ,
246- use_rms_norm : bool = True ,
247- quant_config : Optional [QuantizationConfig ] = None ,
248- prefix : str = "" ,
249- chunk_size : int = - 1 , # the chunk size used by v1
235+ self ,
236+ hidden_size : int ,
237+ ssm_state_size : int ,
238+ conv_kernel_size : int ,
239+ intermediate_size : int ,
240+ use_conv_bias : bool ,
241+ use_bias : bool ,
242+ n_groups : int = 1 ,
243+ num_heads : int = 128 ,
244+ head_dim : int = 64 ,
245+ rms_norm_eps : float = 1e-5 ,
246+ activation : str = "silu" ,
247+ use_rms_norm : bool = True ,
248+ quant_config : Optional [QuantizationConfig ] = None ,
249+ prefix : str = "" ,
250250 ):
251251 super ().__init__ ()
252252
@@ -428,10 +428,7 @@ def __init__(
428428 # of Attention + v0 PP.
429429 # The inner tuple is (conv_state, ssm_state)
430430 self .kv_cache = [(torch .tensor ([]), torch .tensor ([]))]
431- assert chunk_size != - 1 , "chunk_size must be set for v1"
432431
433- # NOTE: chunk_size may be -1 for models without v1 support
434- self .chunk_size = chunk_size
435432 self .prefix = prefix
436433
437434 def forward_native (
0 commit comments