 import torch
 from torch import nn
 from transformers import Gemma3TextConfig
-from transformers.activations import ACT2FN

 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
+from ..modules.gated_mlp import GatedMLP
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.linear import Linear, TensorParallelMode
@@ -156,33 +156,10 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         return super().apply_rope(q, k, v, position_ids)


-class Gemma3MLP(nn.Module):
-
-    def __init__(self, config: Gemma3TextConfig):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
-        self.gate_proj = Linear(self.hidden_size,
-                                self.intermediate_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.up_proj = Linear(self.hidden_size,
-                              self.intermediate_size,
-                              bias=False,
-                              dtype=self.dtype)
-        self.down_proj = Linear(self.intermediate_size,
-                                self.hidden_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
-
-    @torch.inference_mode()
-    def forward(self, x):
-        down_proj = self.down_proj(
-            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+# This function is written to be compatible with TRTLLM's GatedMLP class.
+def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
+    gate, x = gate_x.chunk(2, dim=-1)
+    return nn.functional.gelu(gate, approximate="tanh") * x


 class Gemma3DecoderLayer(DecoderLayer):
@@ -202,7 +179,13 @@ def __init__(
             is_sliding=is_sliding,
         )

-        self.mlp = Gemma3MLP(config)
+        self.mlp = GatedMLP(hidden_size=config.hidden_size,
+                            intermediate_size=config.intermediate_size,
+                            bias=False,
+                            activation=pytorch_gelu_tanh,
+                            dtype=config.torch_dtype,
+                            config=model_config,
+                            layer_idx=layer_idx)

         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
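
For review context, here is a minimal PyTorch-only sketch of why `pytorch_gelu_tanh` accepts a single tensor and chunks it: TRT-LLM's `GatedMLP` is assumed here to run the gate and up projections as one fused GEMM and hand the concatenated result to its `activation` callback, so the callback must split before gating. The `FusedGatedMLP` class, its sizes, and the gate-then-up weight packing below are illustrative assumptions for the sketch, not the real `GatedMLP` implementation; the reference path mirrors the removed `Gemma3MLP.forward` (tanh-approximated GELU on the gate branch, multiplied by the up branch, then the down projection).

```python
import torch
from torch import nn


def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
    # Split the fused [gate | up] tensor, gate with tanh-approximated GELU.
    gate, x = gate_x.chunk(2, dim=-1)
    return nn.functional.gelu(gate, approximate="tanh") * x


class FusedGatedMLP(nn.Module):
    """Hypothetical stand-in for a fused gate/up MLP: one GEMM, then chunk."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size,
                                      2 * intermediate_size,
                                      bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(pytorch_gelu_tanh(self.gate_up_proj(x)))


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden, inter = 8, 16

    # Reference path: the removed Gemma3MLP with separate gate/up projections.
    gate_proj = nn.Linear(hidden, inter, bias=False)
    up_proj = nn.Linear(hidden, inter, bias=False)
    down_proj = nn.Linear(inter, hidden, bias=False)

    # Pack weights as [gate; up] along the output dim to match chunk(2, dim=-1).
    fused = FusedGatedMLP(hidden, inter)
    with torch.no_grad():
        fused.gate_up_proj.weight.copy_(
            torch.cat([gate_proj.weight, up_proj.weight], dim=0))
        fused.down_proj.weight.copy_(down_proj.weight)

    x = torch.randn(2, hidden)
    reference = down_proj(
        nn.functional.gelu(gate_proj(x), approximate="tanh") * up_proj(x))
    assert torch.allclose(fused(x), reference, atol=1e-6)
```

If the fused and reference paths ever diverge, the first thing to check is the packing order of the fused weights (gate first, then up), since the activation's `chunk` call silently assumes it.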