
Conversation

Member

@Cyrilvallez Cyrilvallez commented Mar 28, 2025

What does this PR do?

This is a follow-up of #36963.

This PR makes _init_weights work seamlessly with composite models. Until this point, composite models would only use the _init_weights of the outer-most PreTrainedModel wrapper, leading to errors or skipped modules. Now, sub-models are correctly initialized according to their own _init_weights, without any overhead. This is increasingly important as most recent models are now multimodal.
Without this change, every composite model would have to recurse a second time over all sub-models explicitly in the outer-most _init_weights, which is extremely error-prone and inefficient. For example, we would need to do one of the following in the outer-most _init_weights:

# FIRST BAD OPTION

def _init_weights(self, module):
    std = self.config.initializer_range
       
    # for each module in the model, check the whole module list of the submodel (very inefficient)
    if module in self.vision_tower.modules():
        self.vision_tower._init_weights(module)

    # similar for the other sub-model
    elif module in self.language_model.modules():
        self.language_model._init_weights(module)

    # usual init block for only the modules external to the sub-models
    elif isinstance(module, nn.Linear):
        ...

# OR EQUALLY INEFFICIENT

def _init_weights(self, module):
    std = self.config.initializer_range
       
    # Here, as `apply` is depth-first graph traversal, every module will be initialized a first time, then re-initialized
    # a second time (extremely inefficient as well)
    if module is self.vision_tower:
        self.vision_tower.apply(self.vision_tower._init_weights)

    # similar for the other sub-model
    elif module is self.language_model:
        self.language_model.apply(self.language_model._init_weights)

    # usual init block for only the modules external to the sub-models
    elif isinstance(module, nn.Linear):
        ...

This PR allows us to simply do

def _init_weights(self, module):
    std = self.config.initializer_range

    # usual init block for only the modules external to the sub-models
    if isinstance(module, nn.Linear):
        ...

and have all submodels correctly initialized automatically.
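
As an illustration only (these toy classes are hypothetical stand-ins, not code from this PR or the library), the kind of structure this targets looks like the following, where each sub-model keeps its own init logic and the wrapper only handles the modules it owns directly:

import torch.nn as nn

class TinyVisionTower(nn.Module):
    # stand-in for a PreTrainedModel sub-model with its own init scheme
    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.01)

class TinyLanguageModel(nn.Module):
    # stand-in for a second PreTrainedModel sub-model
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)

class TinyComposite(nn.Module):
    # stand-in for the outer-most PreTrainedModel wrapper
    def __init__(self):
        super().__init__()
        self.vision_tower = TinyVisionTower()
        self.language_model = TinyLanguageModel()
        self.multi_modal_projector = nn.Linear(16, 16)

    def _init_weights(self, module):
        # only the projector (external to the sub-models) is handled here
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)

With this PR, initialize_weights() dispatches each sub-model's own _init_weights to the modules that sub-model owns, so the wrapper's method can stay this small.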

Also, this PR enforces torch.no_grad() during initialization, which was not the case before and slowed down the process.
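
One concrete effect of the no_grad context (a small side illustration, not code from this PR): in-place ops on a parameter that requires grad are only allowed inside no_grad, so init code can write module.weight.normal_() directly instead of going through .data:

import torch

w = torch.nn.Parameter(torch.empty(8, 8))  # requires_grad=True by default

# Outside of no_grad, the next line would raise:
# "a leaf Variable that requires grad is being used in an in-place operation"
# w.normal_(mean=0.0, std=0.02)

with torch.no_grad():
    w.normal_(mean=0.0, std=0.02)  # fine: no autograd tracking during init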

Finally, this PR fixes the _init_weights of a LOT of models, starting with the most important ones (the most recent ones, and the ones with the flag _supports_cache_class=True).
The reason not to do them all at once is simply that there are too many to fix: almost all models in the library have a broken _init_weights 🙃
We'll patch the rest incrementally. In the meantime, the added test will enforce that new models are correct.
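
For orientation, here is a hedged sketch of the typical shape a complete per-module _init_weights has in the library; the module types and constants are generic examples, not lifted from any particular model touched by this PR:

import torch.nn as nn

def _init_weights(self, module):
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.weight.data.fill_(1.0)
        module.bias.data.zero_()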

@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@github-actions github-actions bot marked this pull request as draft March 28, 2025 09:59
@Cyrilvallez Cyrilvallez marked this pull request as ready for review March 28, 2025 10:06
@github-actions github-actions bot requested a review from ydshieh March 28, 2025 10:06
@ydshieh
Collaborator

ydshieh commented Mar 28, 2025

Before a huge refactorization, could you review this one so that hopefully we can merge it? 🙏

@Cyrilvallez Cyrilvallez force-pushed the fix-all-init-weights branch from 8574b67 to 759c5c8 Compare March 31, 2025 15:38
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez force-pushed the fix-all-init-weights branch 2 times, most recently from c84f00b to 39ddd6e Compare April 2, 2025 13:23
@Cyrilvallez Cyrilvallez changed the title from "Detect and fix all _init_weights() issues" to "Detect and fix most _init_weights() issues - make it work for composite models" Apr 2, 2025
Member Author

@Cyrilvallez Cyrilvallez left a comment


cc @ArthurZucker, I highlighted the most important changes. The other changes are just the _init_weights fixes across all the models

Comment on lines +2475 to +2481
@torch.no_grad()
def initialize_weights(self):
    """
    This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
    This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
    module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
    model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
    is extremely error prone and inefficient.
    Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
    `torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
    `module.weight.data.zero_()`.
    """
    if not hasattr(torch.nn.Module, "smart_apply"):
        # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjusts the function
        # to apply as we go down the graph
        def smart_apply(self, fn):
            for module in self.children():
                # We found a sub-model: recursively dispatch its own init function now!
                if hasattr(module, "_init_weights"):
                    module.smart_apply(module._initialize_weights)
                else:
                    module.smart_apply(fn)
            fn(self)
            return self

        torch.nn.Module.smart_apply = smart_apply

    # Let the magic happen with this simple call
    self.smart_apply(self._initialize_weights)
Member Author


This is the most important change to review @ArthurZucker. It's the most efficient and elegant way to handle it, as we only need to traverse the modules once. However, it requires hot-patching torch.nn.Module, which is a bummer but fine IMO.
Other options that avoid doing so all require traversing the modules several times (at least twice), which is less efficient.
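
For context, a hedged sketch of one such multi-pass alternative (purely illustrative; the function name and structure are made up and nothing here is proposed in the PR): first map every module to the closest ancestor that defines _init_weights, then initialize, which means walking over the modules twice instead of once:

def initialize_weights_two_pass(model):
    # Pass 1: map every module to the closest ancestor that defines `_init_weights`
    owner_of = {}

    def assign(module, owner):
        if hasattr(module, "_init_weights"):
            owner = module
        owner_of[module] = owner
        for child in module.children():
            assign(child, owner)

    assign(model, model)

    # Pass 2: initialize every module with its owner's init function
    for module, owner in owner_of.items():
        owner._init_weights(module)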

Comment on lines 511 to 521
filename = inspect.getfile(model_class)
# No easy way to get the model addition date -> check the copyright year at the top of the file
with open(filename) as file:
    source_code = file.read()
addition_year = 0  # if we cannot find it, set it to 0 (i.e. oldest)
if match_object := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE):
    addition_year = int(match_object.group(1))

# For now, skip everything added before 2025 except "important" models (too many models to patch otherwise)
# Use `_supports_cache_class` as a proxy to judge "important" models in order to prioritize them
# TODO: relax this as we patch more and more models
if addition_year < 2025 and not model_class._supports_cache_class:
    self.skipTest(reason=f"{model_class} is not a prioritized model for now.")
Member Author

@Cyrilvallez Cyrilvallez Apr 2, 2025


Because there are too many models to patch otherwise, I'm only enforcing it for the most recent/most important ones for now. Having the test like this will enforce that new models have correct init schemes while we patch the older ones. I'm just using the copyright date as a proxy for the addition date, as it's the easiest way that came to mind.
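
A hedged mini-example of how that proxy behaves (the header string below is illustrative, not taken from any specific model file):

import re

# typical top-of-file header of a modeling file; the year stands in for the addition date
source_code = "# coding=utf-8\n# Copyright 2024 The HuggingFace Inc. team.\n"

addition_year = 0
if match_object := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE):
    addition_year = int(match_object.group(1))

print(addition_year)  # 2024 -> such a model is skipped for now unless it sets _supports_cache_class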

@Cyrilvallez Cyrilvallez force-pushed the fix-all-init-weights branch from 13ee1da to 8e55a6f Compare April 8, 2025 13:44
@ydshieh ydshieh removed their request for review April 14, 2025 09:25
@Cyrilvallez Cyrilvallez force-pushed the fix-all-init-weights branch from 6ef2c6d to e5d5ecc Compare April 14, 2025 10:21
@Cyrilvallez Cyrilvallez force-pushed the fix-all-init-weights branch from bc884ae to ce665b8 Compare April 14, 2025 11:04
@Cyrilvallez
Member Author

run-slow: llama, mistral, mistral3, qwen2_5_vl

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/llama', 'models/mistral', 'models/mistral3', 'models/qwen2_5_vl']
quantizations: [] ...

@Cyrilvallez
Member Author

Slow tests are similar to main; the other failures are Hub timeouts and a flaky test for Qwen2.5 Omni.
Merging

@Cyrilvallez Cyrilvallez merged commit 4e53840 into main Apr 14, 2025
17 of 22 checks passed
@Cyrilvallez Cyrilvallez deleted the fix-all-init-weights branch April 14, 2025 14:19
cyr0930 pushed a commit to cyr0930/transformers that referenced this pull request Apr 18, 2025
…site models (huggingface#37070)

* Update test_modeling_common.py

* Fix Llama and its modular children

* Update test_modeling_common.py

* qwen3

* first try at prioritizing models

* Update test_modeling_common.py

* Update test_modeling_common.py

* Update test_modeling_common.py

* test

* fix

* fix

* more models

* more

* more

* more

* smarter init for composite models!

* fix post rebase

* smol

* fix missing args

* more

* typo

* Super elegant and efficient init for submodels

* Update modeling_utils.py

* style

* last fixes

* cleanup

* finalize cleanup

* CIs

* improve docstring

* Update modeling_utils.py

* llama4

* style

* CIs

* style

* add dpt

* granite speech

* qwen 2.5 omni

* better fix

* Parse the config file instead

* CIs
Collaborator

@ArthurZucker ArthurZucker left a comment


LGTM, not sure how much more efficient it is, but it should be a lot! 🤗

@molbap molbap mentioned this pull request Apr 28, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…site models (huggingface#37070)

@ydshieh ydshieh mentioned this pull request May 26, 2025