diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/gemma3_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/gemma3_weight_mapper.py
index 3f35f2d9016..a8d31d6526d 100644
--- a/tensorrt_llm/_torch/models/checkpoints/hf/gemma3_weight_mapper.py
+++ b/tensorrt_llm/_torch/models/checkpoints/hf/gemma3_weight_mapper.py
@@ -6,6 +6,7 @@
 
 
 @register_mapper("HF", "Gemma3ForCausalLM")
+@register_mapper("HF", "Gemma3ForConditionalGeneration")
 class Gemma3HfWeightMapper(HfWeightMapper):
 
     def should_skip_module(self, module_name: str) -> bool:
diff --git a/tensorrt_llm/_torch/models/modeling_gemma3vl.py b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
index d925b0c1db7..07fb5b5417b 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3vl.py
@@ -1,3 +1,4 @@
+import copy
 import dataclasses
 import os
 from typing import List, Optional, Tuple
@@ -7,6 +8,9 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector
 
+from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
+    BaseWeightMapper
+
 from ..._utils import nvtx_range
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
@@ -98,13 +102,14 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]):
                                          dtype=torch.int32,
                                          device=self._device)
 
-        self.model_config = model_config
+        model_config_cp = copy.deepcopy(model_config)
+        self.model_config = model_config_cp
 
-        llm_model_config = self.get_sub_model_config(model_config,
+        llm_model_config = self.get_sub_model_config(model_config_cp,
                                                      "text_config")
         self.llm = Gemma3ForCausalLM(llm_model_config)
 
-        vision_model_config = self.get_sub_model_config(model_config,
+        vision_model_config = self.get_sub_model_config(model_config_cp,
                                                         "vision_config")
         self.siglip_tower = SiglipVisionModel(vision_model_config,
                                               use_post_layernorm=True)
@@ -141,9 +146,9 @@ def get_sub_model_config(
         sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype
         return sub_model_config
 
-    def load_weights(self, weights):
+    def load_weights(self, weights, weight_mapper: BaseWeightMapper):
         llm_weights = filter_weights("language_model", weights)
-        self.llm.load_weights(llm_weights)
+        self.llm.load_weights(llm_weights, weight_mapper)
         vit_weights = filter_weights("vision_tower", weights)
         self.siglip_tower.load_weights(vit_weights)
 
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 3d115bc05b8..962b87abf72 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -75,6 +75,7 @@ l0_h100:
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
   - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
   - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
+  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
 - condition:
     ranges:
       system_gpu_count:
@@ -193,7 +194,6 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
-  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
   - test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
 - condition:
     ranges:
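
The gist of the patch: the extra `@register_mapper` line lets both Gemma3 architecture names resolve to the same weight mapper, `load_weights` now receives the mapper explicitly instead of resolving one itself, and the added `copy.deepcopy` keeps the sub-config derivation in `get_sub_model_config` from mutating the `ModelConfig` instance the caller still holds. Below is a minimal, self-contained sketch of the registry pattern the first hunk relies on; the class bodies and `MAPPER_REGISTRY` are simplified stand-ins, not TensorRT-LLM internals.

```python
# Sketch of a (format, architecture) -> mapper-class registry, with stacked
# decorators registering one class under several architecture names. The
# registry dict and class bodies here are illustrative stand-ins.

MAPPER_REGISTRY = {}


def register_mapper(fmt: str, architecture: str):
    def decorator(cls):
        # Each stacked application adds one more (format, arch) key
        # pointing at the same class object.
        MAPPER_REGISTRY[(fmt, architecture)] = cls
        return cls

    return decorator


class HfWeightMapper:  # stand-in for the real base class
    def should_skip_module(self, module_name: str) -> bool:
        return False


@register_mapper("HF", "Gemma3ForCausalLM")
@register_mapper("HF", "Gemma3ForConditionalGeneration")
class Gemma3HfWeightMapper(HfWeightMapper):
    pass


# Both the text-only and the multimodal Gemma3 checkpoints now resolve
# to the same mapper class.
assert MAPPER_REGISTRY[("HF", "Gemma3ForCausalLM")] is Gemma3HfWeightMapper
assert MAPPER_REGISTRY[("HF",
                        "Gemma3ForConditionalGeneration")] is Gemma3HfWeightMapper
```

With that mapping in place, the composite VL model only has to forward the mapper to its text submodel (the vision tower keeps its own loading path), which is exactly what the new `load_weights(self, weights, weight_mapper)` signature does.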