
Commit 5d962e8 (1 parent: baece56)

feat: Add LoRA support for Gemma3

Signed-off-by: Balaram Buddharaju <[email protected]>

5 files changed: +31 / -33 lines
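Background: LoRA (Low-Rank Adaptation) keeps the pretrained weights frozen and adds a pair of small low-rank matrices whose product is applied alongside each adapted linear layer; the diffs below simply thread a per-request `lora_params` dictionary from the top-level model down to the attention and MLP modules so those adapters can be applied at runtime. A minimal sketch of the underlying computation follows; the class and attribute names are illustrative, not TensorRT-LLM's internal layout:

import torch
from torch import nn


class LoRALinear(nn.Module):
    # Toy linear layer with a frozen base weight and a low-rank adapter.

    def __init__(self, in_features: int, out_features: int, rank: int = 8,
                 alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)  # pretrained weight stays frozen
        self.lora_a = nn.Linear(in_features, rank, bias=False)   # down-projection A
        self.lora_b = nn.Linear(rank, out_features, bias=False)  # up-projection B
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a numerical no-op
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + (alpha / r) * B(A(x))
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


if __name__ == "__main__":
    layer = LoRALinear(64, 64)
    print(layer(torch.randn(2, 64)).shape)  # torch.Size([2, 64])

TensorRT-LLM's runtime LoRA path is batched and driven by the `lora_params` dictionary rather than per-module adapter attributes; the sketch above only shows the per-layer math that plumbing ultimately serves.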

tensorrt_llm/_torch/models/modeling_gemma3.py
Lines changed: 20 additions & 32 deletions

@@ -4,7 +4,6 @@
 import torch
 from torch import nn
 from transformers import Gemma3TextConfig
-from transformers.activations import ACT2FN

 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
@@ -20,7 +19,8 @@
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
-from ..modules.linear import Linear, TensorParallelMode
+from ..modules.gated_mlp import GatedMLP
+from ..modules.linear import TensorParallelMode
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
@@ -156,33 +156,10 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
         return super().apply_rope(q, k, v, position_ids)


-class Gemma3MLP(nn.Module):
-
-    def __init__(self, config: Gemma3TextConfig):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
-        self.gate_proj = Linear(self.hidden_size,
-                                self.intermediate_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.up_proj = Linear(self.hidden_size,
-                              self.intermediate_size,
-                              bias=False,
-                              dtype=self.dtype)
-        self.down_proj = Linear(self.intermediate_size,
-                                self.hidden_size,
-                                bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
-
-    @torch.inference_mode()
-    def forward(self, x):
-        down_proj = self.down_proj(
-            self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+# This function is written to be compatible with TRTLLM's GatedMLP class.
+def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
+    gate, x = gate_x.chunk(2, dim=-1)
+    return nn.functional.gelu(gate, approximate="tanh") * x


 class Gemma3DecoderLayer(DecoderLayer):
@@ -202,7 +179,13 @@ def __init__(
             is_sliding=is_sliding,
         )

-        self.mlp = Gemma3MLP(config)
+        self.mlp = GatedMLP(hidden_size=config.hidden_size,
+                            intermediate_size=config.intermediate_size,
+                            bias=False,
+                            activation=pytorch_gelu_tanh,
+                            dtype=config.torch_dtype,
+                            config=model_config,
+                            layer_idx=layer_idx)

         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
@@ -226,6 +209,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
+        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:

@@ -238,13 +222,14 @@ def forward(
             attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
             is not None else PredefinedAttentionMask.CAUSAL,
             attention_mask_data=attention_mask_data,
+            lora_params=lora_params,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, lora_params=lora_params)
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states

@@ -285,6 +270,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         local_attention_mask_data: Optional[torch.Tensor] = None,
         global_attention_mask_data: Optional[torch.Tensor] = None,
+        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -304,7 +290,9 @@ def forward(
                 attn_metadata=attn_metadata,
                 attention_mask_data=local_attention_mask_data
                 if decoder_layer.self_attn.is_sliding else
-                global_attention_mask_data)
+                global_attention_mask_data,
+                lora_params=lora_params,
+            )

         hidden_states = self.norm(hidden_states)
         return hidden_states
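The removed `Gemma3MLP` computed `down_proj(gelu_tanh(gate_proj(x)) * up_proj(x))` with separate gate and up projections, whereas `GatedMLP` hands its activation a single fused tensor, which is why `pytorch_gelu_tanh` chunks its input before gating. A minimal equivalence check, assuming the fused layout is `[gate, up]` along the last dimension (separate `nn.Linear` layers stand in for GatedMLP's fused projection):

import torch
from torch import nn


def pytorch_gelu_tanh(gate_x: torch.Tensor) -> torch.Tensor:
    # Split the fused [gate, up] tensor and gate with tanh-approximated GELU.
    gate, x = gate_x.chunk(2, dim=-1)
    return nn.functional.gelu(gate, approximate="tanh") * x


hidden, inter = 32, 64
gate_proj = nn.Linear(hidden, inter, bias=False)
up_proj = nn.Linear(hidden, inter, bias=False)
x = torch.randn(4, hidden)

# Old Gemma3MLP-style path: separate projections, GELU applied to the gate only.
reference = nn.functional.gelu(gate_proj(x), approximate="tanh") * up_proj(x)

# GatedMLP-style path: one concatenated projection output, split inside the activation.
fused = torch.cat([gate_proj(x), up_proj(x)], dim=-1)
assert torch.allclose(reference, pytorch_gelu_tanh(fused), atol=1e-6)

The down projection is handled inside `GatedMLP` itself, so only the activation needed a Gemma3-specific shim.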

tensorrt_llm/_torch/models/modeling_gemma3vl.py
Lines changed: 1 addition & 0 deletions

@@ -213,6 +213,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             return_context_logits=return_context_logits,
             image_token_mask=mm_token_mask,
+            lora_params=kwargs.get("lora_params", None),
         )
         return logits


tensorrt_llm/_torch/modules/gated_mlp.py
Lines changed: 3 additions & 1 deletion

@@ -108,7 +108,9 @@ def __init__(self,
     def _apply_activation(self, x):
         if self.activation == F.silu:
             return swiglu(x)
-        elif self.activation == None:
+        elif callable(self.activation):
+            return self.activation(x)
+        elif self.activation is None:
             return x
         else:
             raise NotImplementedError(
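Note the branch ordering: `F.silu` is itself callable, so the fused `swiglu` fast path must stay ahead of the new generic-callable branch, which now catches custom activations such as `pytorch_gelu_tanh`. A standalone sketch of the same dispatch, using a plain-PyTorch stand-in for the fused `swiglu` op (assumed here to compute `silu(gate) * up` on a chunked input):

from typing import Callable, Optional

import torch
import torch.nn.functional as F


def swiglu_reference(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused kernel: silu-gated product of the two halves.
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up


def apply_activation(x: torch.Tensor,
                     activation: Optional[Callable] = F.silu) -> torch.Tensor:
    if activation == F.silu:        # fused fast path checked first
        return swiglu_reference(x)
    elif callable(activation):      # e.g. pytorch_gelu_tanh from modeling_gemma3.py
        return activation(x)
    elif activation is None:        # pass-through
        return x
    raise NotImplementedError(f"Unsupported activation: {activation}")


x = torch.randn(2, 8)
print(apply_activation(x).shape)        # swiglu path -> torch.Size([2, 4])
print(apply_activation(x, None).shape)  # identity    -> torch.Size([2, 8])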

tests/integration/defs/perf/test_perf.py
Lines changed: 4 additions & 0 deletions

@@ -92,6 +92,7 @@
     "mistral_7b_v0.1": "mistral-7b-v0.1",
     "ministral_8b": "Ministral-8B-Instruct-2410",
     "ministral_8b_fp8": "Ministral-8B-Instruct-2410-FP8",
+    "gemma_3_1b_it": "gemma/gemma-3-1b-it",
     "deepseek_r1_fp8": "DeepSeek-R1/DeepSeek-R1",
     "deepseek_r1_nvfp4": "DeepSeek-R1/DeepSeek-R1-FP4",
     "deepseek_v3_lite_fp8": "DeepSeek-V3-Lite/fp8",
@@ -153,6 +154,7 @@
     "ministral_8b_hf": "mistralai/Ministral-8B-Instruct-2410",
     "flan_t5_base_hf": "google/flan-t5-small",
     "phi_4_mini_instruct_hf": "microsoft/Phi-4-mini-instruct",
+    "gemma_3_1b_it_hf": "google/gemma-3-1b-it",
 }
 LORA_MODEL_PATH = {
     "llama_v2_13b":
@@ -163,6 +165,8 @@
     "lora/llama-3-chinese-8b-instruct-v2-lora/",
     "ministral_8b":
     "lora/ministral/Ministral-8B-Instruct-2410-Loras-Dummy", # Dummy LoRA for Ministral
+    "gemma_3_1b_it":
+    "lora/gemma/gemma-3-1b-it-dummy-lora", # Dummy LoRA for Gemma-3-1B-Instruct
     "phi_4_multimodal_instruct_image":
     "multimodals/Phi-4-multimodal-instruct/vision-lora",
     "phi_4_multimodal_instruct_audio":

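The new `LORA_MODEL_PATH` entry points at a dummy (untrained) adapter rather than a real fine-tune, and the commit does not show how that checkpoint was produced. One way such a Hugging Face-format dummy LoRA for `google/gemma-3-1b-it` could be generated is with peft; the rank, target modules, and output directory below are illustrative assumptions, not what the test data actually contains:

# Sketch: produce an untrained LoRA adapter checkpoint for perf testing.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")

lora_config = LoraConfig(
    r=8,                  # adapter rank (illustrative)
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(base, lora_config)
# Writes adapter_config.json plus the adapter weights; peft zero-initializes
# the B matrices, so the adapter is a numerical no-op, which is fine for perf runs.
peft_model.save_pretrained("gemma-3-1b-it-dummy-lora")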
tests/integration/test_lists/qa/trt_llm_integration_perf_test.yml
Lines changed: 3 additions & 0 deletions

@@ -45,6 +45,9 @@ trt_llm_integration_perf_test:
 - perf/test_perf.py::test_perf[llama_v3.1_8b-cpp-ootb_except_mha-bfloat16-maxbs:64-bs:64-input_output_len:128,8+512,32]
 - perf/test_perf.py::test_perf[llama_v3.1_8b-cpp-ootb_except_mha-bfloat16-maxbs:64-bs:64-input_output_len:128,128+512,32]

+# Dummy lora tests
+- perf/test_perf.py::test_perf[gemma_3_1b_it-bench-pytorch-bfloat16-maxbs:2-maxnt:1024-input_output_len:128,128-loras:1-reqs:8-con:2]
+
 # Test list validation
 - test_list_validation.py::test_list_validation

