
Conversation

@sgugger sgugger (Collaborator) commented May 11, 2021

What does this PR do?

This fixes how weights are loaded when resuming training from a checkpoint, in the instances where some weights are tied to others (and thus not saved). It also adds a test to the common tests to make sure the mechanism used is not broken by mistake.

Fixes #11666
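
Below is a minimal, self-contained sketch of the mechanism this fix relies on; it is not the actual Trainer code. The `TinyLM` model and its key names are hypothetical stand-ins for a model whose `lm_head.weight` is tied to the embedding and therefore listed in `_keys_to_ignore_on_save`: the checkpoint is loaded with `strict=False`, and if the only missing keys are the ones that were deliberately not saved, the weights are re-tied instead of raising an error.

```python
import torch
import torch.nn as nn

# Hypothetical toy model (not from transformers): the output head is meant to
# share its weight with the input embedding, so "lm_head.weight" would be
# listed in _keys_to_ignore_on_save and left out of saved checkpoints.
class TinyLM(nn.Module):
    def __init__(self, vocab_size=10, hidden=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

    def tie_weights(self):
        # Make the output projection reuse the embedding matrix.
        self.lm_head.weight = self.embed.weight

    def forward(self, ids):
        return self.lm_head(self.embed(ids))

# "Save" a checkpoint from a model with tied weights, dropping the tied key
# just as saving with _keys_to_ignore_on_save would.
trained = TinyLM()
trained.tie_weights()
checkpoint = {k: v for k, v in trained.state_dict().items() if k != "lm_head.weight"}

# Resume: a strict load would raise 'Missing key(s) in state_dict: "lm_head.weight"'.
# Load leniently, then re-tie only if the missing keys are exactly the ones
# that were deliberately not saved.
resumed = TinyLM()
load_result = resumed.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys == ["lm_head.weight"]:
    resumed.tie_weights()

assert torch.equal(resumed.lm_head.weight, trained.embed.weight)
```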

@sgugger sgugger requested a review from LysandreJik May 11, 2021 15:48
@LysandreJik LysandreJik (Member) left a comment

Nice, thank you for the fix!

Comment on lines +1063 to +1064
if load_result.missing_keys == self.model._keys_to_ignore_on_save:
self.model.tie_weights()

That's a nice check!

@sgugger sgugger merged commit f13f1f8 into master May 11, 2021
@sgugger sgugger deleted the test_checkpointing branch May 11, 2021 16:02
Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
* Add test and see where CI is unhappy

* Load with strict=False


Development

Successfully merging this pull request may close these issues.

GPTNeoForCausalLM: resuming Trainer from checkpoint causes Missing key(s) in state_dict: "lm_head.weight"

3 participants