
Conversation

@Cyrilvallez (Member) commented on Mar 25, 2025

What does this PR do?

This PR removes the now-useless _fast_init and low_cpu_mem_usage in from_pretrained, in order to simplify things even further and limit the number of code paths, ultimately making them much easier to maintain and debug. These two parameters should always be True anyway for optimized model loading.

Because a LOT of models have bad _init_weights() methods (i.e. they do not init ALL parameters), this could be an issue when loading a corrupted state dict (i.e. a state dict with missing weights, where one of the missing weights is not handled properly by _init_weights). However, this should not be an issue in general, as we don't expect too many corrupted state dicts on the Hub. Moreover, this bug is ALREADY PRESENT whenever such a model is loaded with a device_map or with low_cpu_mem_usage=True (or any option that ends up activating low_cpu_mem_usage=True). This is because doing so forces the parameters to be loaded on the meta device, so weights initialized in the __init__ of a layer or similar (which assumes the model is instantiated on CPU) end up with wrong values when moved back to CPU.
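
To make the failure mode concrete, here is a minimal standalone sketch (not from this PR, names are illustrative) of why a value assigned in __init__ does not survive a round trip through the meta device:

import torch
from torch import nn

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        # value assigned eagerly in __init__, as some layers in the library do
        self.scale = nn.Parameter(torch.ones(4))

# instantiating under the meta device never actually allocates the ones() above
with torch.device("meta"):
    module = Scaler()

# materializing afterwards yields uninitialized memory, so a proper _init_weights
# must re-create the value on the real device
module = module.to_empty(device="cpu")
print(module.scale)  # arbitrary values, not ones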

Nevertheless, it can be hard to debug and should not happen, so this PR already fixes some models' _init_weights. In parallel, #37070 adds a test that always detects when a model's _init_weights is missing some parameters, and I will fix more models directly there (it relies on the fact that _fast_init and low_cpu_mem_usage are already gone).
Fun fact: even our faithful Llama has a bad _init_weights!! (missing the RMSNorm) 🤯
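
For context, a complete _init_weights needs a branch for every parameter-owning leaf module. A rough sketch of what that looks like for a Llama-style model (illustrative only, not the exact diff in this PR; the 0.02 std mirrors the usual initializer_range default):

import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm

def init_weights(module, std=0.02):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, LlamaRMSNorm):
        # this is the kind of branch that was missing: without it, the norm weight
        # stays uninitialized when the parameter is materialized from meta
        module.weight.data.fill_(1.0)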

Most of the changed files simply remove old _fast_init tests (which were skipped anyway 🙃🙃) and fix weight initialization for a few models that were blocking general CI tests.

github-actions (bot) commented:

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers.

github-actions bot marked this pull request as draft on March 25, 2025 at 13:58
@Cyrilvallez marked this pull request as ready for review on March 25, 2025 at 13:59
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sssshhhhhh commented:

Hi, this is causing all weights to be meta tensors when from_flax=True, for at least Whisper and BERT. This was already broken with low_cpu_mem_usage before, so not unexpected I guess.

from transformers import BertModel
model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)  # also 'openai/whisper-tiny'
# this assertion passes: the loaded weight is an empty meta tensor instead of real data
assert model.state_dict()['embeddings.word_embeddings.weight'].is_meta

I'm only using this to convert from JAX, so it's not a big deal to use an older version. Sorry if you're already aware of this issue.

@Cyrilvallez (Member, Author) commented:

Hey @sssshhhhhh! We were actually not aware of it; it slipped under the radar as from_flax/from_tf are extremely rarely used! So rarely that apparently nobody ever reported that it was broken with a device_map (which used to implicitly activate low_cpu_mem_usage) 😳 I must say I'm quite surprised by this, as it has been in the codebase for quite a long time.

However, we will very soon start deprecating Flax and TF. Since loading with from_flax/from_tf relies on the model architecture from the underlying library, this means we will also stop supporting the from_flax/from_tf flags in from_pretrained. As a result, I don't think loading with these flags will be fixed in the current state of the library (main).

So I do think the easiest path is indeed to use an older version to convert the checkpoints to PyTorch if needed, then re-save them. Would that be an acceptable way to proceed for your use case, or would it cause too much friction/discomfort?
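
For reference, that one-off conversion could look roughly like this (the version pin and output path are illustrative; any release from before this change should do):

# pip install "transformers==4.49.0"
from transformers import BertModel

model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
model.save_pretrained("./bert-base-uncased-pt")  # re-save as PyTorch weights
# later, on a current transformers version, load the converted checkpoint normally:
# model = BertModel.from_pretrained("./bert-base-uncased-pt")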

@farzadab commented on Apr 14, 2025:

@Cyrilvallez I spent a lot of time trying to figure this out but I'm still left with no good solution.

Changing the behaviour of _init_weights as you suggested cannot solve this issue, because it assumes each inner module can be initialized separately (since it's called via model.apply), and that's not what I want. I want to be able to load sub-models (e.g. language_model and audio_tower) from checkpoints (e.g. the HF Hub).

The best I can think of is to override _load_pretrained_model, then somehow find the checkpoints for the sub-models and call language_model._load_pretrained_model on them. What makes this extremely hard is replicating the checkpoint-finding behaviour of .from_pretrained, since from_pretrained is not very modular (1300 lines).

@Cyrilvallez (Member, Author) commented:

Hey! Yes, repos are expected to contain all their weights, so things would be much simpler if you added all the weights (i.e. the weights of the sub-models) to your repo directly. However, if you don't want to do that, I believe _init_weights can still be used with something along the lines of:

def _init_weights(self, module):
    if module is self.language_model:
        # swap the meta-initialized sub-model for one loaded from its own checkpoint
        self.language_model = module.from_pretrained(...)
    elif module in self.language_model.modules():
        # already handled by the sub-model's own loading, nothing to do
        pass
    ....

but passing specific args (i.e. the same args as the outer call) to that inner from_pretrained will require a bit more hacking.

from_pretrained should be far fewer lines now as well; we simplified it a lot 🤗 Are you sure you're looking at main?

Hope this solves your issue 🤗

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* Remove low_cpu_mem_usage and _fast_init

* Update deepspeed.py

* Update modeling_utils.py

* remove the first 2 tests everywhere

* Update test_modeling_common.py

* remove what was remaining about fast_init

* fix logic and simplify

* mismatched keys logic update

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix 2 models init_weights

* extend to others

* remove grad

* Update modeling_fsmt.py

* init weights in tests

* style

* Update test_modeling_fsmt.py

* more old models

* fix more init_weights

* copies

* fix

* style

* Update modeling_lxmert.py

* fix inits

* more and more

* more

* should finalize

* style

* Update modeling_dinov2_with_registers.py

* fix

* Update modeling_encoder_decoder.py

* fix

* style

* Update modeling_lxmert.py

* post rebase cleanup

* Update modeling_informer.py

* back to start for device

* fix

* add test to detect all failing cases correctly

* Update test_modeling_common.py

* fix

* fix

* sam

* style

* Update modeling_maskformer_swin.py

* CIs

* CIs

* remove test - will add it on separate PR

* fix

* fix

* Update modeling_sam.py

* CIs

* CIs

* CIs

* convnext

* suggestions

* CIs

* fix copies after merge

---------

Co-authored-by: Yih-Dar <[email protected]>
soghomon-b pushed a commit to soghomon-b/transformers that referenced this pull request Aug 24, 2025