@@ -34,15 +34,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)

         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)

+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
        # override qkv
        self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
@@ -52,6 +57,7 @@ def __init__(
         )

         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx

         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
@@ -90,11 +96,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -133,9 +143,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
@@ -166,13 +178,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]

         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm
