import torch
from torch import nn
from transformers import Gemma3TextConfig
-from transformers.activations import ACT2FN

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
    BaseWeightMapper
@@ -20,7 +19,8 @@
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
-from ..modules.linear import Linear, TensorParallelMode
+from ..modules.gated_mlp import GatedMLP
+from ..modules.linear import TensorParallelMode
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -156,33 +156,10 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
        return super().apply_rope(q, k, v, position_ids)


-class Gemma3MLP(nn.Module):
-
-    def __init__(self, config: Gemma3TextConfig):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
-        self.gate_proj = Linear(self.hidden_size,
-                                self.intermediate_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.up_proj = Linear(self.hidden_size,
-                              self.intermediate_size,
-                              bias=False,
-                              dtype=self.dtype)
-        self.down_proj = Linear(self.intermediate_size,
-                                self.hidden_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
-
-    @torch.inference_mode()
-    def forward(self, x):
-        down_proj = self.down_proj(
-            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+# This function is written to be compatible with TRTLLM's GatedMLP class.
+def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
+    gate, x = gate_x.chunk(2, dim=-1)
+    return nn.functional.gelu(gate, approximate="tanh") * x


class Gemma3DecoderLayer(DecoderLayer):
@@ -202,7 +179,13 @@ def __init__(
            is_sliding=is_sliding,
        )

-        self.mlp = Gemma3MLP(config)
+        self.mlp = GatedMLP(hidden_size=config.hidden_size,
+                            intermediate_size=config.intermediate_size,
+                            bias=False,
+                            activation=pytorch_gelu_tanh,
+                            dtype=config.torch_dtype,
+                            config=model_config,
+                            layer_idx=layer_idx)

        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                       eps=config.rms_norm_eps,
@@ -226,6 +209,7 @@ def forward(
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor] = None,
        attention_mask_data: Optional[torch.Tensor] = None,
+        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:

@@ -238,13 +222,14 @@ def forward(
            attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
            is not None else PredefinedAttentionMask.CAUSAL,
            attention_mask_data=attention_mask_data,
+            lora_params=lora_params,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, lora_params=lora_params)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

@@ -285,6 +270,7 @@ def forward(
        inputs_embeds: Optional[torch.FloatTensor] = None,
        local_attention_mask_data: Optional[torch.Tensor] = None,
        global_attention_mask_data: Optional[torch.Tensor] = None,
+        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
@@ -304,7 +290,9 @@ def forward(
                attn_metadata=attn_metadata,
                attention_mask_data=local_attention_mask_data
                if decoder_layer.self_attn.is_sliding else
-                global_attention_mask_data)
+                global_attention_mask_data,
+                lora_params=lora_params,
+            )

        hidden_states = self.norm(hidden_states)
        return hidden_states
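Note on the activation swap above: the removed Gemma3MLP computed act_fn(gate_proj(x)) * up_proj(x) with separate gate and up projections, while the new pytorch_gelu_tanh receives a single tensor and chunks it into gate and up halves along the last dimension, i.e. it assumes GatedMLP hands it the fused [gate | up] projection output. Below is a standalone sketch of that equivalence, not part of the commit; it assumes the fused concatenated layout and that config.hidden_activation resolves to the tanh-approximate GELU (Gemma3's "gelu_pytorch_tanh").

# Standalone sketch (assumptions: GatedMLP feeds the activation a fused
# [gate | up] tensor, and the old act_fn was tanh-approximate GELU).
import torch
from torch import nn


def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
    gate, x = gate_x.chunk(2, dim=-1)
    return nn.functional.gelu(gate, approximate="tanh") * x


hidden_size, intermediate_size = 16, 32
inp = torch.randn(2, hidden_size)

gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)

# Old path: separate projections, tanh-approximate GELU on the gate branch.
ref = nn.functional.gelu(gate_proj(inp), approximate="tanh") * up_proj(inp)

# New path: one fused projection whose output is [gate | up] on the last dim.
fused = torch.cat([gate_proj(inp), up_proj(inp)], dim=-1)
out = pytorch_gelu_tanh(fused)

assert torch.allclose(ref, out, atol=1e-6)

Under those assumptions the two formulations produce the same values, which is what makes the GatedMLP drop-in replacement safe while also picking up its tensor-parallel and LoRA support.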