
Commit 8151fd0

switch to GatedMLP

1 parent b754adf commit 8151fd0

File tree

2 files changed: +15 -30 lines changed

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 12 additions & 29 deletions
@@ -4,7 +4,6 @@
 import torch
 from torch import nn
 from transformers import Gemma3TextConfig
-from transformers.activations import ACT2FN
 
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
@@ -18,6 +17,7 @@
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
+from ..modules.gated_mlp import GatedMLP
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.linear import Linear, TensorParallelMode
@@ -156,33 +156,10 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         return super().apply_rope(q, k, v, position_ids)
 
 
-class Gemma3MLP(nn.Module):
-
-    def __init__(self, config: Gemma3TextConfig):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
-        self.gate_proj = Linear(self.hidden_size,
-                                self.intermediate_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.up_proj = Linear(self.hidden_size,
-                              self.intermediate_size,
-                              bias=False,
-                              dtype=self.dtype)
-        self.down_proj = Linear(self.intermediate_size,
-                                self.hidden_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
-
-    @torch.inference_mode()
-    def forward(self, x):
-        down_proj = self.down_proj(
-            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+# This function is written to be compatible with TRTLLM's GatedMLP class.
+def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
+    gate, x = gate_x.chunk(2, dim=-1)
+    return nn.functional.gelu(gate, approximate="tanh") * x
 
 
 class Gemma3DecoderLayer(DecoderLayer):
@@ -202,7 +179,13 @@ def __init__(
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(config)
+        self.mlp = GatedMLP(hidden_size=config.hidden_size,
+                            intermediate_size=config.intermediate_size,
+                            bias=False,
+                            activation=pytorch_gelu_tanh,
+                            dtype=config.torch_dtype,
+                            config=model_config,
+                            layer_idx=layer_idx)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
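
For context only (not part of the commit): a minimal sketch checking that the fused-activation path reproduces the removed Gemma3MLP math. GatedMLP is assumed to hand the activation a tensor with the gate and up projections concatenated on the last dimension, so chunking it and applying tanh-approximated GELU to the gate half should match act_fn(gate_proj(x)) * up_proj(x). The layer names below are illustrative stand-ins, not TRT-LLM modules.

import torch
from torch import nn

hidden_size, intermediate_size = 16, 32
x = torch.randn(2, hidden_size)

# Stand-ins for the separate gate/up projections of the removed Gemma3MLP.
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)

# Old formulation: act_fn(gate_proj(x)) * up_proj(x) with tanh-approximated GELU.
old = nn.functional.gelu(gate_proj(x), approximate="tanh") * up_proj(x)

# New formulation: one fused [gate, up] tensor passed to the custom activation.
def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
    gate, up = gate_x.chunk(2, dim=-1)
    return nn.functional.gelu(gate, approximate="tanh") * up

fused = torch.cat([gate_proj(x), up_proj(x)], dim=-1)
new = pytorch_gelu_tanh(fused)

torch.testing.assert_close(old, new)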

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 3 additions & 1 deletion
@@ -108,7 +108,9 @@ def __init__(self,
     def _apply_activation(self, x):
         if self.activation == F.silu:
            return swiglu(x)
-        elif self.activation == None:
+        elif callable(self.activation):
+            return self.activation(x)
+        elif self.activation is None:
            return x
        else:
            raise NotImplementedError(
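
As a usage note, the new callable branch means any function that accepts the concatenated [gate, up] tensor can be supplied as activation (pytorch_gelu_tanh above being the motivating case). A rough standalone sketch of the resulting dispatch, with the fused swiglu kernel approximated by its eager-mode equivalent; this is an illustration under those assumptions, not the module's actual code.

import torch
import torch.nn.functional as F

def apply_activation_sketch(x: torch.Tensor, activation) -> torch.Tensor:
    if activation == F.silu:
        # The real path calls a fused swiglu kernel; eager-mode equivalent:
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up
    elif callable(activation):
        # New branch: custom callables receive the fused [gate, up] tensor.
        return activation(x)
    elif activation is None:
        return x
    else:
        raise NotImplementedError(f"Unsupported activation: {activation}")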
