Commit 38648e9

[None][feat] Switch to internal version of MMProjector in Gemma3
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent f923974 commit 38648e9

File tree

1 file changed: +62 additions, -17 deletions

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 62 additions & 17 deletions
@@ -5,8 +5,6 @@

 import torch
 from transformers import AutoProcessor, Gemma3Config, PreTrainedModel
-from transformers.modeling_utils import no_init_weights
-from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector

 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
@@ -18,6 +16,8 @@
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
+from ..modules.linear import Linear
+from ..modules.rms_norm import RMSNorm
 from .modeling_gemma3 import Gemma3ForCausalLM
 from .modeling_multimodal_utils import fuse_input_embeds
 from .modeling_siglip import SiglipVisionModel
@@ -81,6 +81,61 @@ def __call__(
         return input_ids[0].to(torch.int32).tolist(), multimodal_data


+# Original HF implementation:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/modeling_gemma3.py#L684
+class Gemma3MultiModalProjector(torch.nn.Module):
+    """Gemma3MultiModalProjector using TRTLLM's Linear and RMSNorm."""
+
+    def __init__(self, model_config: ModelConfig[Gemma3Config]):
+        super().__init__()
+        config = model_config.pretrained_config
+        self.dtype = config.torch_dtype
+        self.mm_input_projection = Linear(
+            in_features=config.vision_config.hidden_size,
+            out_features=config.text_config.hidden_size,
+            bias=False,
+            dtype=self.dtype,
+            mapping=model_config.mapping)
+        self.mm_soft_emb_norm = RMSNorm(
+            hidden_size=config.vision_config.hidden_size,
+            eps=config.vision_config.layer_norm_eps,
+            dtype=self.dtype)
+        self.patches_per_image = int(config.vision_config.image_size //
+                                     config.vision_config.patch_size)
+        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
+        self.kernel_size = self.patches_per_image // self.tokens_per_side
+        self.avg_pool = torch.nn.AvgPool2d(kernel_size=self.kernel_size,
+                                           stride=self.kernel_size)
+
+    def load_weights(self, weights):
+        # Original `mm_input_projection_weight` is a matmul while we use a linear op with no bias.
+        self.mm_input_projection.weight.data.copy_(
+            weights["mm_input_projection_weight"].transpose(0, 1))
+        # Gemma3RmsNorm is a layernorm-1P and needs a +1.0.
+        self.mm_soft_emb_norm.weight.data.copy_(
+            weights["mm_soft_emb_norm.weight"] + 1.0)
+
+    @torch.inference_mode()
+    def forward(self, vision_outputs: torch.Tensor):
+        batch_size, _, seq_length = vision_outputs.shape
+        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
+        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+            batch_size, seq_length, self.patches_per_image,
+            self.patches_per_image)
+        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
+        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs).to(
+            self.dtype)
+        pooled_vision_outputs = pooled_vision_outputs.flatten(2)  # [B, T, P*P].
+        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2).reshape(
+            -1, seq_length)
+        # FlashInfer's rmsnorm needs input to be contiguous.
+        normed_vision_outputs = self.mm_soft_emb_norm(
+            pooled_vision_outputs.contiguous())
+        projected_vision_outputs = self.mm_input_projection(
+            normed_vision_outputs)
+        return projected_vision_outputs.type_as(vision_outputs)
+
+
 @register_auto_model("Gemma3ForConditionalGeneration")
 @register_input_processor(Gemma3InputProcessor, model_type="gemma3")
 class Gemma3VLM(PreTrainedModel):
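Why load_weights rewrites the checkpoint tensors: the HF checkpoint stores `mm_input_projection_weight` as a plain matmul operand of shape `[vision_hidden_size, text_hidden_size]`, while a `Linear` layer stores `[out_features, in_features]`, hence the transpose; and Gemma3's RMSNorm scales by `(1 + weight)`, so folding the +1.0 into the stored weight lets a standard RMSNorm (which scales by `weight` alone) reproduce it. Below is a minimal PyTorch sketch of that equivalence with hypothetical tensors `w_proj` and `w_norm`, assuming TRTLLM's RMSNorm follows the usual `x * weight` convention:

import torch

# Hypothetical checkpoint-style tensors; shapes follow the HF Gemma3 layout.
vision_hidden, text_hidden, eps = 8, 4, 1e-6
w_proj = torch.randn(vision_hidden, text_hidden)  # mm_input_projection_weight
w_norm = torch.randn(vision_hidden)               # mm_soft_emb_norm.weight
x = torch.randn(3, vision_hidden)

# HF-style computation: RMSNorm scales by (1 + weight), projection is x @ W.
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
ref = (rms * (1.0 + w_norm)) @ w_proj

# Same result with a standard Linear layout after the load_weights transforms.
linear = torch.nn.Linear(vision_hidden, text_hidden, bias=False)
with torch.no_grad():
    linear.weight.copy_(w_proj.transpose(0, 1))  # Linear stores [out_features, in_features].
out = linear(rms * (w_norm + 1.0))               # fold the +1.0 into the norm weight

assert torch.allclose(ref, out, atol=1e-5)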
@@ -114,11 +169,8 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]):
         self.siglip_tower = SiglipVisionModel(vision_model_config,
                                                use_post_layernorm=True)

-        # NOTE: Use HF implementation. We init the weights after transferring to the `device` since it can take a much
-        # longer time to initialize them on the CPU.
-        with no_init_weights():
-            self.mm_projector = Gemma3MultiModalProjector(config).eval().to(
-                self._device)
+        self.mm_projector = Gemma3MultiModalProjector(model_config).eval().to(
+            self._device)

         self.post_config()
         self.is_loaded = True
@@ -153,12 +205,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         vit_weights = filter_weights("vision_tower", weights)
         self.siglip_tower.load_weights(vit_weights)

-        _load_weights_into_hf_module(
-            model=self.mm_projector,
-            weights=weights,
-            prefix="multi_modal_projector",
-            model_name="multi modal projector",
-        )
+        mm_projector_weights = filter_weights("multi_modal_projector", weights)
+        self.mm_projector.load_weights(mm_projector_weights)

     def post_config(self):
         self.config = self.llm.config
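The projector now loads the same way as the vision tower: the checkpoint entries under the `multi_modal_projector` prefix are filtered down and handed to the module's own `load_weights`, which expects prefix-free keys such as `mm_input_projection_weight` and `mm_soft_emb_norm.weight`. A rough sketch of that kind of prefix filtering, as an illustrative stand-in assuming the repo's `filter_weights` helper strips the matching prefix:

def filter_by_prefix(weights: dict, prefix: str) -> dict:
    # Keep only entries under `prefix.` and strip the prefix from the key.
    # Illustrative stand-in; the actual filter_weights helper may differ.
    return {
        name[len(prefix) + 1:]: tensor
        for name, tensor in weights.items()
        if name.startswith(prefix + ".")
    }

# e.g. "multi_modal_projector.mm_soft_emb_norm.weight" -> "mm_soft_emb_norm.weight"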
@@ -191,12 +239,9 @@ def forward(
         mm_embeds = []
         mm_token_mask = None
         if len(pixel_values) > 0:
-            # The shape of `image_features` is `[B, T, embed_dim]`.
             image_features = self._get_image_features(
                 pixel_values=torch.cat(pixel_values))
-            # We need to reshape it to `[B * T, embed_dim]` before passing to `fuse_input_embeds`.
-            B, T, embed_dim = image_features.shape
-            mm_embeds = [image_features.reshape(B * T, embed_dim).contiguous()]
+            mm_embeds = [image_features.contiguous()]

         # Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens.
         mm_token_mask = torch.isin(input_ids, self.image_token_ids)
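Dropping the `[B, T, embed_dim] -> [B * T, embed_dim]` reshape here is consistent with the new projector: its `forward` already flattens the batch and token dimensions via `reshape(-1, seq_length)` before the norm, so the image features arrive as a 2D tensor that `fuse_input_embeds` can consume directly. For a sense of the token counts involved, a small worked example with assumed Gemma3-style config values (896-pixel images, 14-pixel patches, 256 image tokens; these numbers are not taken from this diff):

# Assumed Gemma3-style configuration values (illustrative only).
image_size, patch_size, mm_tokens_per_image = 896, 14, 256

patches_per_image = image_size // patch_size        # 896 // 14 = 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # sqrt(256) = 16
kernel_size = patches_per_image // tokens_per_side  # 64 // 16 = 4 -> 4x4 average pooling

# 64 * 64 = 4096 patch embeddings are pooled down to 16 * 16 = 256 image tokens,
# which the projector returns flattened as [batch * 256, text_hidden_size].
assert (patches_per_image, tokens_per_side, kernel_size) == (64, 16, 4)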
