-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
Fix(gemma3_mm): Add robust weight loading for quantized VLM #20066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |||||||||||||
| from vllm.config import VllmConfig | ||||||||||||||
| from vllm.logger import init_logger | ||||||||||||||
| from vllm.model_executor.layers.layernorm import GemmaRMSNorm | ||||||||||||||
| from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||||||||||||||
| from vllm.model_executor.models.module_mapping import MultiModelKeys | ||||||||||||||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||||||||||||||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||||||||||||||
|
|
@@ -35,9 +36,8 @@ | |||||||||||||
| from .interfaces import (MultiModalEmbeddings, SupportsLoRA, | ||||||||||||||
| SupportsMultiModal, SupportsPP) | ||||||||||||||
| from .siglip import SiglipVisionModel | ||||||||||||||
| from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, | ||||||||||||||
| init_vllm_registered_model, maybe_prefix, | ||||||||||||||
| merge_multimodal_embeddings) | ||||||||||||||
| from .utils import (WeightsMapper, flatten_bn, init_vllm_registered_model, | ||||||||||||||
| maybe_prefix, merge_multimodal_embeddings) | ||||||||||||||
|
|
||||||||||||||
| logger = init_logger(__name__) | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -707,8 +707,41 @@ def compute_logits( | |||||||||||||
|
|
||||||||||||||
| def load_weights(self, weights: Iterable[tuple[str, | ||||||||||||||
| torch.Tensor]]) -> set[str]: | ||||||||||||||
| loader = AutoWeightsLoader(self) | ||||||||||||||
| return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) | ||||||||||||||
| # Custom weight loader for Gemma3 VLM to handle naming inconsistencies. | ||||||||||||||
| # This loader first applies the class-level hf_to_vllm_mapper and then | ||||||||||||||
| # applies a targeted fix for the "double prefix" issue within the | ||||||||||||||
| # vision model component. | ||||||||||||||
|
|
||||||||||||||
| params_dict = dict(self.named_parameters()) | ||||||||||||||
| loaded_params: set[str] = set() | ||||||||||||||
|
|
||||||||||||||
| for original_name, loaded_weight in weights: | ||||||||||||||
| name = original_name | ||||||||||||||
|
|
||||||||||||||
| # Apply the standard class-level mapper first. | ||||||||||||||
| if self.hf_to_vllm_mapper is not None: | ||||||||||||||
| name = self.hf_to_vllm_mapper.map(name) | ||||||||||||||
|
|
||||||||||||||
| # Apply the targeted hotfix only if a vision weight is not found. | ||||||||||||||
| # This prevents regressions on other models. | ||||||||||||||
| if name not in params_dict and name.startswith("vision_model."): | ||||||||||||||
| potential_name = f"vision_model.{name}" | ||||||||||||||
| if potential_name in params_dict: | ||||||||||||||
| name = potential_name | ||||||||||||||
|
|
||||||||||||||
| # Load the weight using the potentially corrected name. | ||||||||||||||
| if name not in params_dict: | ||||||||||||||
| # Silently skip any weights that are still not found. | ||||||||||||||
| loaded_params.add(original_name) | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| param = params_dict[name] | ||||||||||||||
| weight_loader = getattr(param, "weight_loader", | ||||||||||||||
| default_weight_loader) | ||||||||||||||
|
Comment on lines
+738
to
+740
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider checking if
Suggested change
|
||||||||||||||
| weight_loader(param, loaded_weight) | ||||||||||||||
| loaded_params.add(original_name) | ||||||||||||||
|
|
||||||||||||||
| return loaded_params | ||||||||||||||
|
|
||||||||||||||
| def get_mm_mapping(self) -> MultiModelKeys: | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's generally better to log a warning message when skipping weights, even if it's done silently. This can aid in debugging if unexpected weights are not loaded. Consider using
logger.warning.