Clean load keys #24505
Conversation
The documentation is not available anymore as the PR was closed or merged.
```diff
  # If this weight is going to tip up over the maximal size, we split.
- if last_block_size + weight_size > max_shard_size:
+ if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
```
This change is necessary to make sure we save something in the first shard. With the removal of position_ids from the saved tensors, some checkpoint-sharding tests with BERT started failing.
I think it's worth adding a comment mentioning why we're doing this!
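For context, here is a minimal sketch of the kind of shard-splitting loop this guard sits in (a simplified, hypothetical version, not the exact modeling_utils implementation):

```python
import torch

def shard_state_dict(state_dict, max_shard_size):
    # Simplified sketch: accumulate weights into shards until the size budget is reached.
    sharded_state_dicts = [{}]
    last_block_size = 0
    for key, weight in state_dict.items():
        weight_size = weight.numel() * weight.element_size()
        # If this weight would tip over the maximal size, start a new shard, but only if
        # the current shard already holds something. Without the second condition, a first
        # weight bigger than max_shard_size would leave the first shard empty.
        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
            sharded_state_dicts.append({})
            last_block_size = 0
        sharded_state_dicts[-1][key] = weight
        last_block_size += weight_size
    return sharded_state_dicts
```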
```python
model.tie_weights()
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    id_tensor = id_tensor_storage(tensor) if tensor.device != torch.device("meta") else id(tensor)
```
Accelerate detects tied weights based on IDs, but that doesn't work for all models (deformable_detr for instance). So we use the same test as elsewhere, except when the tensor is on the meta device (where that test fails), in which case we default to id.
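A rough sketch of the detection logic being discussed (the helper name `storage_id` is a hypothetical stand-in for `id_tensor_storage`, and the code assumes PyTorch >= 2.0 for `untyped_storage()`):

```python
import collections
import torch

def storage_id(tensor):
    # Tensors that share the same underlying storage (tied weights) map to the same key.
    return (tensor.device, tensor.untyped_storage().data_ptr(), tensor.untyped_storage().nbytes())

def find_shared_weight_names(model):
    ptrs = collections.defaultdict(list)
    for name, tensor in model.state_dict().items():
        # Meta tensors have no real storage, so fall back to the Python id.
        id_tensor = storage_id(tensor) if tensor.device != torch.device("meta") else id(tensor)
        ptrs[id_tensor].append(name)
    # Any group with more than one name corresponds to tied/shared parameters.
    return [names for names in ptrs.values() if len(names) > 1]
```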
```python
        self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))

    # XXX: this might be a candidate for common tests if we have many of those
    def test_lm_head_ignore_keys(self):
```
This was testing the hack added to remove the weights from `_keys_to_ignore_on_save` when they were untied.
Nice to see it go
| f"The shared pointers are incorrect, found different pointers for keys {shared_names}", | ||
| ) | ||
|
|
||
| def test_load_save_without_tied_weights(self): |
This new test checks that when weights are untied, they are properly saved and we complain if they are missing from the checkpoint.
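A rough idea of what such a test can look like (hypothetical configuration values, not the actual test added in this PR):

```python
import tempfile
from transformers import BertConfig, BertForMaskedLM

def check_load_save_without_tied_weights():
    # Untie the input embeddings and the LM head so both must be saved separately.
    config = BertConfig(
        tie_word_embeddings=False, num_hidden_layers=2, hidden_size=32, num_attention_heads=4
    )
    model = BertForMaskedLM(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        # Reloading should not report any missing keys: the untied LM head weights
        # must be present in the checkpoint now that they are no longer ignored on save.
        _, loading_info = BertForMaskedLM.from_pretrained(tmp_dir, output_loading_info=True)
    assert loading_info["missing_keys"] == []
```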
```diff
-    def test_tied_model_weights_key_ignore(self):
+    def test_model_weights_reload_no_missing_tied_weights(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
```
The previous test mostly focused on checking that tied weights were listed in the `_keys_to_ignore_on_load_missing` class variable, but we don't put them there anymore, so the test is adapted.
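As a rough sketch, the adapted check can be thought of along these lines (hypothetical and simplified, not the actual common test):

```python
import tempfile
from transformers import BertConfig, BertForMaskedLM

def check_model_weights_reload_no_missing_tied_weights():
    # BERT ties the input embeddings and the LM head decoder by default.
    model = BertForMaskedLM(BertConfig(num_hidden_layers=2, hidden_size=32, num_attention_heads=4))
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        reloaded, loading_info = BertForMaskedLM.from_pretrained(tmp_dir, output_loading_info=True)
    # No keys should be reported missing even though tied weights are stored only once,
    # and after reload the two weights should still share the same storage.
    assert loading_info["missing_keys"] == []
    assert (
        reloaded.get_input_embeddings().weight.data_ptr()
        == reloaded.get_output_embeddings().weight.data_ptr()
    )
```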
LysandreJik left a comment:
Looks good to me! Thanks for this change!
amyeroberts left a comment:
Really nice tidy up - thanks for working on this and updating!
```diff
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True
+        )
```
Just double-checking this should be persistent=True. Assuming yes, given the other buffers, but most other models seem to have it as False.
CLAP has the persistent=False -> persistent=True change in its copied-from statement, that's why. I didn't want to break it accidentally, so I didn't touch that statement.
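For reference, a tiny standalone illustration of what `persistent=` changes (not transformers code): persistent buffers land in the state dict and therefore in the checkpoint, while non-persistent ones are recreated at init and never saved.

```python
import torch
from torch import nn

class WithBuffers(nn.Module):
    def __init__(self, max_position_embeddings=512):
        super().__init__()
        position_ids = torch.arange(max_position_embeddings).expand((1, -1))
        self.register_buffer("saved_ids", position_ids, persistent=True)
        self.register_buffer("transient_ids", position_ids.clone(), persistent=False)

module = WithBuffers()
print("saved_ids" in module.state_dict())      # True  -> written to the checkpoint
print("transient_ids" in module.state_dict())  # False -> absent, no ignore key needed
```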
```diff
         # Initialize weights and apply final processing
         self.post_init()
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
```
Does it still work with the removal of this buffer?
Ah, it was a duplicate (not shown by the diff). You can scroll to line 452 below to see it defined again with persistent=False.
@sgugger with this change, a few of our trained models that used the older format are no longer able to properly load unless we set …
What does this PR do?
This PR finishes the work done in a previous PR and completely cleans up `_keys_to_ignore_on_save`, `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_load_unexpected`. Those were used in three situations:

1. Not saving the tied weights. This came from the (wrong) assumption that torch would take twice the space for tied weights (which it doesn't) and also created bugs where non-tied weights were not saved (unless a hack was added, like for RoBERTa models). This is not necessary since PyTorch doesn't take more space for tied weights, and safetensors will properly remove them (with `_tied_weights_keys`).
2. Ignoring non-saved non-persistent buffers. This can be done automatically in the code of modeling_utils, as non-persistent buffers are keys in the model's named buffers that are not in the state dict, so they are easy to detect (see the sketch after this list).
3. Ignoring known unexpected weights from another architecture (like the pooler). This isn't necessary anymore since we don't issue a warning in this case.
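A hedged sketch of the automatic detection mentioned in point 2 (the helper name is hypothetical, not the exact modeling_utils code): non-persistent buffers show up in `named_buffers()` but have no entry in `state_dict()`, so the difference between the two gives the keys that can safely be ignored when loading.

```python
import torch
from torch import nn

def non_persistent_buffer_names(model: nn.Module):
    # Buffers present on the module but absent from the state dict are non-persistent.
    state_dict_keys = set(model.state_dict().keys())
    return [name for name, _ in model.named_buffers() if name not in state_dict_keys]

class Embeddings(nn.Module):
    def __init__(self, max_position_embeddings=512):
        super().__init__()
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

print(non_persistent_buffer_names(Embeddings()))  # ['position_ids']
```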