Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions vllm/model_executor/models/gemma3_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
Expand All @@ -35,9 +36,8 @@
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .utils import (WeightsMapper, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

logger = init_logger(__name__)

Expand Down Expand Up @@ -707,8 +707,41 @@ def compute_logits(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# Custom weight loader for Gemma3 VLM to handle naming inconsistencies.
# This loader first applies the class-level hf_to_vllm_mapper and then
# applies a targeted fix for the "double prefix" issue within the
# vision model component.

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()

for original_name, loaded_weight in weights:
name = original_name

# Apply the standard class-level mapper first.
if self.hf_to_vllm_mapper is not None:
name = self.hf_to_vllm_mapper.map(name)

# Apply the targeted hotfix only if a vision weight is not found.
# This prevents regressions on other models.
if name not in params_dict and name.startswith("vision_model."):
potential_name = f"vision_model.{name}"
if potential_name in params_dict:
name = potential_name

# Load the weight using the potentially corrected name.
if name not in params_dict:
# Silently skip any weights that are still not found.
loaded_params.add(original_name)
continue
Comment on lines +734 to +736
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's generally better to log a warning message when skipping weights, even if it's done silently. This can aid in debugging if unexpected weights are not loaded. Consider using logger.warning.

Suggested change
# Silently skip any weights that are still not found.
loaded_params.add(original_name)
continue
if name not in params_dict:
logger.warning(f"Skipping weight {original_name} as it is not found in the model.")
loaded_params.add(original_name)
continue


param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Comment on lines +738 to +740
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider checking if param has the attribute weight_loader before calling getattr. This can prevent potential AttributeError exceptions if a parameter unexpectedly lacks this attribute.

Suggested change
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
param = params_dict[name]
weight_loader = (getattr(param, "weight_loader", None) or
default_weight_loader)

weight_loader(param, loaded_weight)
loaded_params.add(original_name)

return loaded_params

def get_mm_mapping(self) -> MultiModelKeys:
"""
Expand Down