@@ -6,6 +6,7 @@


 @register_mapper("HF", "Gemma3ForCausalLM")
+@register_mapper("HF", "Gemma3ForConditionalGeneration")
 class Gemma3HfWeightMapper(HfWeightMapper):

     def should_skip_module(self, module_name: str) -> bool:
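Stacking a second register_mapper decorator presumably binds the same mapper class to both Gemma3 architecture names, so checkpoints declaring either Gemma3ForCausalLM or Gemma3ForConditionalGeneration resolve to Gemma3HfWeightMapper. A minimal sketch of that registration pattern, assuming a registry keyed on (framework, architecture); the names below are illustrative, not the actual TensorRT-LLM internals:

# Illustrative sketch only: a decorator-based registry keyed on
# (framework, architecture). The real register_mapper in TensorRT-LLM
# may differ in signature and storage.
from typing import Callable, Dict, Tuple, Type

_MAPPER_REGISTRY: Dict[Tuple[str, str], Type] = {}

def register_mapper(framework: str, architecture: str) -> Callable[[Type], Type]:
    def decorator(cls: Type) -> Type:
        _MAPPER_REGISTRY[(framework, architecture)] = cls
        return cls  # return cls unchanged so decorators can stack
    return decorator

# With stacked decorators, both keys point at the same class:
# _MAPPER_REGISTRY[("HF", "Gemma3ForCausalLM")] is
# _MAPPER_REGISTRY[("HF", "Gemma3ForConditionalGeneration")]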
15 changes: 10 additions & 5 deletions tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -1,3 +1,4 @@
+import copy
 import dataclasses
 import os
 from typing import List, Optional, Tuple
@@ -7,6 +8,9 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector

+from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
+    BaseWeightMapper
+
 from ..._utils import nvtx_range
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
@@ -98,13 +102,14 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]):
                                           dtype=torch.int32,
                                           device=self._device)

-        self.model_config = model_config
+        model_config_cp = copy.deepcopy(model_config)
+        self.model_config = model_config_cp

-        llm_model_config = self.get_sub_model_config(model_config,
+        llm_model_config = self.get_sub_model_config(model_config_cp,
                                                      "text_config")
         self.llm = Gemma3ForCausalLM(llm_model_config)

-        vision_model_config = self.get_sub_model_config(model_config,
+        vision_model_config = self.get_sub_model_config(model_config_cp,
                                                         "vision_config")
         self.siglip_tower = SiglipVisionModel(vision_model_config,
                                               use_post_layernorm=True)
@@ -141,9 +146,9 @@ def get_sub_model_config(
         sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype
         return sub_model_config

-    def load_weights(self, weights):
+    def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         llm_weights = filter_weights("language_model", weights)
-        self.llm.load_weights(llm_weights)
+        self.llm.load_weights(llm_weights, weight_mapper)

         vit_weights = filter_weights("vision_tower", weights)
         self.siglip_tower.load_weights(vit_weights)
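The switch from storing model_config directly to a deepcopy matters because get_sub_model_config derives both the text and vision sub-configs from the same object; if that derivation mutates shared nested config state, the caller's config would be silently changed. A minimal sketch of the aliasing hazard, using stand-in objects rather than the real ModelConfig:

import copy
from types import SimpleNamespace

# Stand-in for ModelConfig: nested config objects are held by reference,
# so a shallow copy still aliases pretrained_config and its sub-configs.
parent = SimpleNamespace(pretrained_config=SimpleNamespace(
    torch_dtype="bfloat16", text_config=SimpleNamespace(num_layers=62)))

shallow = copy.copy(parent)
shallow.pretrained_config.torch_dtype = "float16"
print(parent.pretrained_config.torch_dtype)  # float16 -- caller's config mutated

deep = copy.deepcopy(parent)
deep.pretrained_config.torch_dtype = "float32"
print(parent.pretrained_config.torch_dtype)  # float16 -- caller unaffected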
2 changes: 1 addition & 1 deletion tests/integration/test_lists/test-db/l0_h100.yml
@@ -75,6 +75,7 @@ l0_h100:
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
   - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
+  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
 - condition:
     ranges:
       system_gpu_count:
@@ -193,7 +194,6 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
-  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
   - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
 - condition:
     ranges:
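Net effect of the YAML change: the gemma-3-27b-it multimodal quickstart entry moves from one condition-gated block to another earlier in the file. Each block in these test-db lists is gated by a condition carrying range bounds on fields such as system_gpu_count; a purely illustrative sketch of how such a gate could be evaluated when selecting tests (the actual test-db loader may differ):

# Illustrative only: evaluate a test-db style range condition against the
# detected machine. Field names mirror the YAML above; the real loader in
# tests/integration may implement this differently.
def matches_ranges(condition: dict, system: dict) -> bool:
    for field, bounds in condition.get("ranges", {}).items():
        value = system[field]
        if "gte" in bounds and value < bounds["gte"]:
            return False
        if "lte" in bounds and value > bounds["lte"]:
            return False
    return True

# Example: a block requiring exactly one GPU.
cond = {"ranges": {"system_gpu_count": {"gte": 1, "lte": 1}}}
print(matches_ranges(cond, {"system_gpu_count": 1}))  # True
print(matches_ranges(cond, {"system_gpu_count": 4}))  # False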