Detect and fix most _init_weights() issues - make it work for composite models #37070
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
Before a huge refactor, could you review this one so that hopefully we can merge it? 🙏
Force-pushed from 8574b67 to 759c5c8
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from c84f00b to 39ddd6e
cc @ArthurZucker, I highlighted the most important changes. The other changes are just the _init_weights fixes on all the models.
@torch.no_grad()
def initialize_weights(self):
    """
    This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
    This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
    module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
    model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
    is extremely error prone and inefficient.

    Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
    `torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
    `module.weight.data.zero_()`.
    """
    if not hasattr(torch.nn.Module, "smart_apply"):
        # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjusts the function
        # to apply as we go down the graph
        def smart_apply(self, fn):
            for module in self.children():
                # We found a sub-model: recursively dispatch its own init function now!
                if hasattr(module, "_init_weights"):
                    module.smart_apply(module._initialize_weights)
                else:
                    module.smart_apply(fn)
            fn(self)
            return self

        torch.nn.Module.smart_apply = smart_apply

    # Let the magic happen with this simple call
    self.smart_apply(self._initialize_weights)
This is the most important change to review @ArthurZucker. It's the most efficient and elegant way to handle it, as we only need to traverse the modules once. However, it requires hot-patching torch.nn.Module, which is a bummer but fine IMO.
The other options that avoid doing so all require traversing the modules several times (at least twice), which is less efficient.
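For context, a rough sketch of what such a multi-pass alternative could look like (illustrative only, not code from this PR; it is a method fragment assuming the same PreTrainedModel context as initialize_weights above): the outer model first applies its own init everywhere, then walks the module tree again for every sub-model that defines its own _init_weights.

def initialize_weights_naive(self):
    # Pass 1: apply the outer-most model's init to every module
    self.apply(self._initialize_weights)
    # Pass 2+: traverse again and re-apply each sub-model's own init over its subtree
    for module in self.modules():
        if module is not self and hasattr(module, "_init_weights"):
            module.apply(module._initialize_weights)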
tests/test_modeling_common.py (outdated)
filename = inspect.getfile(model_class)
# No easy way to get the model addition date -> check the copyright year at the top of the file
with open(filename) as file:
    source_code = file.read()
addition_year = 0  # if we cannot find it, set it to 0 (i.e. oldest)
if match_object := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE):
    addition_year = int(match_object.group(1))

# For now, skip everything that is both older than 2025 and not an "important" model (too many models to patch otherwise)
# Use `_supports_cache_class` as a proxy to judge "important" models in order to prioritize them
# TODO: relax this as we patch more and more models
if addition_year < 2025 and not model_class._supports_cache_class:
    self.skipTest(reason=f"{model_class} is not a prioritized model for now.")
Because there are too many models to patch otherwise, I'm only enforcing it for the most recent/most important ones for now. Having the test like this will enforce that new models have correct init schemes while we patch the older ones. I'm just using the copyright date as a proxy for the addition date, as it's the easiest way that came to mind.
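For illustration, a tiny standalone example of that copyright-year heuristic on a made-up file header (the header string below is invented, not taken from an actual model file):

import re

# Made-up header standing in for the top of a modeling file
source_code = "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team.\n"

addition_year = 0  # fallback when no copyright line is found
if match := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE):
    addition_year = int(match.group(1))

print(addition_year)  # -> 2025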
Force-pushed from 13ee1da to 8e55a6f
Force-pushed from 6ef2c6d to e5d5ecc
Force-pushed from bc884ae to ce665b8
run-slow: llama, mistral, mistral3, qwen2_5_vl
This comment contains run-slow, running the specified jobs: models: ['models/llama', 'models/mistral', 'models/mistral3', 'models/qwen2_5_vl']
Slow tests are similar to main; the other failures are hub timeouts and a flaky test for Qwen2.5 Omni.
LGTM, not sure how much more efficient it is, but it should be a lot! 🤗
What does this PR do?
This is a follow-up of #36963.
This PR makes _init_weights work seamlessly with composite models. Until this point, composite models would only use the _init_weights of the outer-most PreTrainedModel wrapper, leading to errors or skipped modules. Now, sub-models are correctly initialized according to their own _init_weights, without any overhead. This is increasingly important as most recent models are now multimodal.

Without this change, every composite model would have to recurse a second time on all sub-models explicitly in the outer-most _init_weights, which is extremely error prone and inefficient. E.g., we would need to do one or the other of the following in the outer-most _init_weights:
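A rough sketch of what that could look like, using made-up class names (MyVisionModel, MyCompositeModel) and made-up init schemes rather than actual library code:

import torch.nn as nn


class MyVisionModel(nn.Module):
    # Stand-in sub-model with its own init scheme (illustrative only)
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, 16, kernel_size=4)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.01)


class MyCompositeModel(nn.Module):
    # Stand-in outer model: its _init_weights has to know about every sub-model
    def __init__(self):
        super().__init__()
        self.vision_model = MyVisionModel()
        self.lm_head = nn.Linear(16, 100)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        # Option 1: when the traversal reaches the sub-model, re-apply its own init
        # over its whole subtree (an extra traversal of that subtree)
        elif isinstance(module, MyVisionModel):
            module.apply(module._init_weights)
        # Option 2 (not shown): duplicate the sub-model's init logic inline here,
        # module type by module type


model = MyCompositeModel()
model.apply(model._init_weights)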
This PR allows each model to simply write its own _init_weights for its own modules and have all submodels correctly initialized automatically (see the sketch below).
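Continuing the toy example above (same made-up names; not the PR's original snippet), the outer-most _init_weights can now cover only its own direct modules, and the sub-model's scheme is dispatched automatically by initialize_weights:

class MyCompositeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_model = MyVisionModel()  # initialized by its own _init_weights
        self.lm_head = nn.Linear(16, 100)

    def _init_weights(self, module):
        # Only this model's own modules; no explicit recursion into sub-models
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)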
Also, enforce torch.no_grad() for initialization, which was not the case before and would slow down the process.
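A minimal standalone illustration of the no_grad point (not from the PR itself): under torch.no_grad(), in-place ops directly on parameters are allowed and not tracked by autograd; without it, the first in-place call below would raise a RuntimeError.

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)

with torch.no_grad():
    layer.weight.zero_()   # in-place op directly on the Parameter: fine under no_grad
    layer.bias.fill_(0.0)  # no autograd tracking, no graph overhead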
Finally, fix the _init_weights of a LOT of models, the most important ones (the most recent ones, and the ones with the flag _supports_cache_class=True) for now. The reason not to do them all is simply that there are too many to fix; almost all models in the library have broken _init_weights 🙃. We'll patch incrementally. In the meantime, the added test will enforce that new models are correct.