from torch import nn
from transformers import FalconH1Config

+from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsV0Only)
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
@@ -85,6 +85,7 @@ def __init__(
        config: FalconH1Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
@@ -107,6 +108,8 @@ def __init__(
            activation=config.hidden_act,
            quant_config=quant_config,
            use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+            chunk_size=config.mamba_chunk_size,
        )
        # n_groups is overridden later by `MambaMixer2`
        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@@ -316,18 +319,26 @@ def __init__(
        prefix: str = "",
    ) -> None:
        super().__init__()
+
        # Instantiate the attention branch
        self.self_attn = FalconH1AttentionDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )
+
+        # In V1, every attention/SSM layer must have
+        # a distinct index in its prefix
+        ssm_layer_idx = config.num_hidden_layers + layer_idx
+        ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
+
        # Instantiate the SSM branch
        self.mamba = FalconH1SSMDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
+            prefix=ssm_prefix,
        )
        self.ssm_out_multiplier = config.ssm_out_multiplier
        self.ssm_in_multiplier = config.ssm_in_multiplier
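The `ssm_prefix` computed in this hunk exists because the V1 engine keys per-layer state by the numeric index embedded in each layer's prefix, so the attention branch and the SSM branch of one hybrid layer must not share an index. Below is a minimal standalone sketch of that offsetting scheme, assuming a hypothetical `make_ssm_prefix` helper and an illustrative prefix layout; it is not vLLM code.

# Hypothetical helper mirroring the prefix-offsetting idea from the hunk above:
# the SSM branch of hybrid layer `layer_idx` registers under an index shifted
# by `num_hidden_layers`, so it can never collide with an attention layer.
def make_ssm_prefix(prefix: str, layer_idx: int, num_hidden_layers: int) -> str:
    root = prefix.split(".")[0]  # keep only the leading component of the prefix
    ssm_layer_idx = num_hidden_layers + layer_idx
    return f"{root}.{ssm_layer_idx}"


if __name__ == "__main__":
    # Assumed example: with 32 decoder layers, the attention branch of layer 3
    # keeps "model.3" while its SSM branch becomes "model.35".
    print(make_ssm_prefix("model.3", layer_idx=3, num_hidden_layers=32))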
@@ -452,10 +463,16 @@ def forward(
        # proper continuous batching computation including
        # chunked prefill
        attn_metadata = get_forward_context().attn_metadata
-        mamba2_metadata = prepare_mamba2_metadata(
-            chunk_size=self.config.mamba_chunk_size,
-            attn_metadata=attn_metadata,
-        )
+
+        if not envs.VLLM_USE_V1:
+            mamba2_metadata = prepare_mamba2_metadata(
+                chunk_size=self.config.mamba_chunk_size,
+                attn_metadata=attn_metadata,
+            )
+        else:
+            # V1 gets mamba2_metadata from the forward context
+            mamba2_metadata = None
+
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds * self.embedding_multiplier
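The branch introduced in this hunk keeps the eager `prepare_mamba2_metadata` call on the V0 path only; under V1 the mixer layers pull their metadata from the forward context, so the model simply threads `None` through. A rough, self-contained sketch of that selection logic, with stand-in names (`select_mamba2_metadata`, a plain dict instead of the real metadata object):

from typing import Any, Optional


def select_mamba2_metadata(use_v1: bool, chunk_size: int,
                           attn_metadata: Any) -> Optional[dict]:
    # Sketch only: mirrors the V0/V1 branching above, not vLLM's real API.
    if not use_v1:
        # V0: build the chunked-prefill metadata up front from attn_metadata.
        return {"chunk_size": chunk_size, "attn_metadata": attn_metadata}
    # V1: each mixer reads its metadata from the forward context, so nothing
    # needs to be passed down explicitly.
    return None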
@@ -468,7 +485,9 @@ def forward(

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
-            layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
+            layer_mamba_cache_params = None
+            if mamba_cache_params:
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
            hidden_states = layer(
                positions=positions,
                hidden_states=hidden_states,
@@ -484,7 +503,7 @@ def forward(

class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                          IsHybrid, SupportsV0Only):
+                          IsHybrid):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
@@ -558,15 +577,19 @@ def forward(
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
-        if self.mamba_cache is None:
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                self.lm_head.weight.dtype
-                if hasattr(self.lm_head, 'weight') else torch.bfloat16,
-                self.config.num_hidden_layers,
-                *self._get_mamba_cache_shape(),
-            )
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        mamba_cache_params = None
+        if not envs.VLLM_USE_V1:
+            if self.mamba_cache is None:
+                self.mamba_cache = MambaCacheManager(
+                    self.vllm_config,
+                    self.lm_head.weight.dtype if hasattr(
+                        self.lm_head, 'weight') else torch.bfloat16,
+                    self.config.num_hidden_layers,
+                    *self._get_mamba_cache_shape(),
+                )
+            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
        hidden_states = self.model(
            input_ids,
            positions,
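With this last hunk the `MambaCacheManager` becomes a V0-only concern: under V1 the engine owns the SSM state, so `mamba_cache_params` stays `None` and the lazily created cache is never allocated. A toy sketch of this lazy, flag-gated pattern, using a dummy cache class rather than vLLM's `MambaCacheManager`:

from typing import Any, Optional


class _DummyMambaCache:
    # Stand-in for a real cache manager; returns placeholder run tensors.
    def current_run_tensors(self, **kwargs: Any) -> dict:
        return {"state": "per-request SSM tensors would live here"}


class _Model:
    def __init__(self, use_v1: bool) -> None:
        self.use_v1 = use_v1
        self.mamba_cache: Optional[_DummyMambaCache] = None

    def forward(self, **kwargs: Any) -> Optional[dict]:
        mamba_cache_params = None
        if not self.use_v1:
            if self.mamba_cache is None:  # allocate once, and only on V0
                self.mamba_cache = _DummyMambaCache()
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        return mamba_cache_params


print(_Model(use_v1=True).forward())   # None: the V1 engine manages the state
print(_Model(use_v1=False).forward())  # V0: tensors come from the cache manager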