 from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
                                            PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
-from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
@@ -105,9 +104,6 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
-        mrope_config: Optional[dict] = None,
-        all_reduce_params: Optional[AllReduceParams] = None,
-        lora_params: Optional[dict] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -121,9 +117,6 @@ def forward(
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             attention_mask=attention_mask,
-            mrope_config=mrope_config,
-            all_reduce_params=all_reduce_params,
-            lora_params=lora_params,
             attention_window_size=self.attention_window_size,
             attention_mask_data=attention_mask_data,
             **kwargs)
@@ -209,7 +202,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -222,14 +214,14 @@ def forward(
             attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
             is not None else PredefinedAttentionMask.CAUSAL,
             attention_mask_data=attention_mask_data,
-            lora_params=lora_params,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states, lora_params=lora_params)
+        hidden_states = self.mlp(hidden_states,
+                                 lora_params=kwargs.get("lora_params", None))
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states
 
@@ -270,7 +262,6 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         local_attention_mask_data: Optional[torch.Tensor] = None,
         global_attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -291,7 +282,7 @@ def forward(
                 attention_mask_data=local_attention_mask_data
                 if decoder_layer.self_attn.is_sliding else
                 global_attention_mask_data,
-                lora_params=lora_params,
+                **kwargs,
             )
 
         hidden_states = self.norm(hidden_states)
@@ -465,6 +456,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             local_attention_mask_data=local_attention_mask_data,
             global_attention_mask_data=global_attention_mask_data,
+            **kwargs,
         )
 
         return self.logits_processor.forward(
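
The hunks above consistently remove explicitly declared optional arguments (mrope_config, all_reduce_params, lora_params) from the forward() signatures and thread them through **kwargs instead, with the MLP call pulling lora_params back out via kwargs.get("lora_params", None). Below is a minimal, standalone sketch of that forwarding pattern under stated assumptions: ToyMLP and ToyDecoderLayer are hypothetical stand-ins for illustration only, not the TensorRT-LLM Gemma3 classes.

# Sketch of the **kwargs-forwarding pattern adopted by this diff.
# ToyMLP / ToyDecoderLayer are hypothetical; only the argument plumbing matters.
from typing import Any, Optional

import torch
import torch.nn as nn


class ToyMLP(nn.Module):
    """Stand-in MLP: accepts lora_params like the real module, unused here."""

    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self,
                x: torch.Tensor,
                lora_params: Optional[dict] = None) -> torch.Tensor:
        # A real implementation would apply the LoRA deltas described by lora_params.
        return self.proj(x)


class ToyDecoderLayer(nn.Module):
    """Stand-in decoder layer: no explicit lora_params parameter in its signature."""

    def __init__(self, hidden: int):
        super().__init__()
        self.mlp = ToyMLP(hidden)

    def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # lora_params arrives (if at all) inside **kwargs and is looked up only
        # where it is consumed, mirroring kwargs.get("lora_params", None) above.
        return self.mlp(hidden_states,
                        lora_params=kwargs.get("lora_params", None))


layer = ToyDecoderLayer(hidden=8)
x = torch.randn(2, 8)
out = layer(x)                      # callers may omit lora_params entirely
out = layer(x, lora_params=None)    # or pass it through without the layer declaring it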