5 | 5 |
6 | 6 | import torch |
7 | 7 | from transformers import AutoProcessor, Gemma3Config, PreTrainedModel |
8 | | -from transformers.modeling_utils import no_init_weights |
9 | | -from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector |
10 | 8 |
11 | 9 | from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ |
12 | 10 | BaseWeightMapper |
18 | 16 | from ...sampling_params import SamplingParams |
19 | 17 | from ..attention_backend import AttentionMetadata |
20 | 18 | from ..model_config import ModelConfig |
| 19 | +from ..modules.linear import Linear |
| 20 | +from ..modules.rms_norm import RMSNorm |
21 | 21 | from .modeling_gemma3 import Gemma3ForCausalLM |
22 | 22 | from .modeling_multimodal_utils import fuse_input_embeds |
23 | 23 | from .modeling_siglip import SiglipVisionModel |
@@ -81,6 +81,61 @@ def __call__( |
81 | 81 | return input_ids[0].to(torch.int32).tolist(), multimodal_data |
82 | 82 |
83 | 83 |
| 84 | +# Original HF implementation: |
| 85 | +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/modeling_gemma3.py#L684 |
| 86 | +class Gemma3MultiModalProjector(torch.nn.Module): |
| 87 | + """Gemma3MultiModalProjector using TRTLLM's Linear and RMSNorm.""" |
| 88 | + |
| 89 | + def __init__(self, model_config: ModelConfig[Gemma3Config]): |
| 90 | + super().__init__() |
| 91 | + config = model_config.pretrained_config |
| 92 | + self.dtype = config.torch_dtype |
| 93 | + self.mm_input_projection = Linear( |
| 94 | + in_features=config.vision_config.hidden_size, |
| 95 | + out_features=config.text_config.hidden_size, |
| 96 | + bias=False, |
| 97 | + dtype=self.dtype, |
| 98 | + mapping=model_config.mapping) |
| 99 | + self.mm_soft_emb_norm = RMSNorm( |
| 100 | + hidden_size=config.vision_config.hidden_size, |
| 101 | + eps=config.vision_config.layer_norm_eps, |
| 102 | + dtype=self.dtype) |
| 103 | + self.patches_per_image = int(config.vision_config.image_size // |
| 104 | + config.vision_config.patch_size) |
| 105 | + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) |
| 106 | + self.kernel_size = self.patches_per_image // self.tokens_per_side |
| 107 | + self.avg_pool = torch.nn.AvgPool2d(kernel_size=self.kernel_size, |
| 108 | + stride=self.kernel_size) |
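| | +        # For reference, assuming the default Gemma3 vision config (image_size=896, |
| | +        # patch_size=14, mm_tokens_per_image=256): patches_per_image = 64 and |
| | +        # tokens_per_side = 16, so a 4x4 average pool reduces the 64x64 patch grid |
| | +        # to the 256 pooled image tokens consumed by the language model. |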
| 109 | + |
| 110 | + def load_weights(self, weights): |
| 111 | +        # HF applies `mm_input_projection_weight` via a bare matmul; here it feeds a bias-free Linear, so the weight must be transposed. |
| 112 | + self.mm_input_projection.weight.data.copy_( |
| 113 | + weights["mm_input_projection_weight"].transpose(0, 1)) |
| 114 | +        # Gemma3's RMSNorm uses the (1 + weight) parameterization, while TRTLLM's RMSNorm scales by weight directly, so fold the +1.0 into the stored weight. |
| 115 | + self.mm_soft_emb_norm.weight.data.copy_( |
| 116 | + weights["mm_soft_emb_norm.weight"] + 1.0) |
| 117 | + |
| 118 | + @torch.inference_mode() |
| 119 | + def forward(self, vision_outputs: torch.Tensor): |
| 120 | + batch_size, _, seq_length = vision_outputs.shape |
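| | +        # NOTE: following the HF reference's naming, `seq_length` here is actually the |
| | +        # vision hidden size; dim 1 of `vision_outputs` is the number of image patches. |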
| 121 | + reshaped_vision_outputs = vision_outputs.transpose(1, 2) |
| 122 | + reshaped_vision_outputs = reshaped_vision_outputs.reshape( |
| 123 | + batch_size, seq_length, self.patches_per_image, |
| 124 | + self.patches_per_image) |
| 125 | + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() |
| 126 | + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs).to( |
| 127 | + self.dtype) |
| 128 | +        pooled_vision_outputs = pooled_vision_outputs.flatten(2) # [B, seq_length, mm_tokens_per_image]. |
| 129 | + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2).reshape( |
| 130 | + -1, seq_length) |
| 131 | + # FlashInfer's rmsnorm needs input to be contiguous. |
| 132 | + normed_vision_outputs = self.mm_soft_emb_norm( |
| 133 | + pooled_vision_outputs.contiguous()) |
| 134 | + projected_vision_outputs = self.mm_input_projection( |
| 135 | + normed_vision_outputs) |
| 136 | + return projected_vision_outputs.type_as(vision_outputs) |
| 137 | + |
| 138 | + |
84 | 139 | @register_auto_model("Gemma3ForConditionalGeneration") |
85 | 140 | @register_input_processor(Gemma3InputProcessor, model_type="gemma3") |
86 | 141 | class Gemma3VLM(PreTrainedModel): |
@@ -114,11 +169,8 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]): |
114 | 169 | self.siglip_tower = SiglipVisionModel(vision_model_config, |
115 | 170 | use_post_layernorm=True) |
116 | 171 |
117 | | - # NOTE: Use HF implementation. We init the weights after transferring to the `device` since it can take a much |
118 | | - # longer time to initialize them on the CPU. |
119 | | - with no_init_weights(): |
120 | | - self.mm_projector = Gemma3MultiModalProjector(config).eval().to( |
121 | | - self._device) |
| 172 | + self.mm_projector = Gemma3MultiModalProjector(model_config).eval().to( |
| 173 | + self._device) |
122 | 174 |
123 | 175 | self.post_config() |
124 | 176 | self.is_loaded = True |
@@ -153,12 +205,8 @@ def load_weights(self, weights, weight_mapper: BaseWeightMapper): |
153 | 205 | vit_weights = filter_weights("vision_tower", weights) |
154 | 206 | self.siglip_tower.load_weights(vit_weights) |
155 | 207 |
156 | | - _load_weights_into_hf_module( |
157 | | - model=self.mm_projector, |
158 | | - weights=weights, |
159 | | - prefix="multi_modal_projector", |
160 | | - model_name="multi modal projector", |
161 | | - ) |
| 208 | + mm_projector_weights = filter_weights("multi_modal_projector", weights) |
| 209 | + self.mm_projector.load_weights(mm_projector_weights) |
162 | 210 |
163 | 211 | def post_config(self): |
164 | 212 | self.config = self.llm.config |
@@ -191,12 +239,9 @@ def forward( |
191 | 239 | mm_embeds = [] |
192 | 240 | mm_token_mask = None |
193 | 241 | if len(pixel_values) > 0: |
194 | | - # The shape of `image_features` is `[B, T, embed_dim]`. |
195 | 242 | image_features = self._get_image_features( |
196 | 243 | pixel_values=torch.cat(pixel_values)) |
197 | | - # We need to reshape it to `[B * T, embed_dim]` before passing to `fuse_input_embeds`. |
198 | | - B, T, embed_dim = image_features.shape |
199 | | - mm_embeds = [image_features.reshape(B * T, embed_dim).contiguous()] |
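| | +            # `_get_image_features` now returns already-flattened `[B * T, embed_dim]` |
| | +            # embeddings (the projector reshapes internally), so no reshape is needed |
| | +            # before `fuse_input_embeds`. |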
| 244 | + mm_embeds = [image_features.contiguous()] |
200 | 245 |
201 | 246 | # Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens. |
202 | 247 | mm_token_mask = torch.isin(input_ids, self.image_token_ids) |