
sgugger (Collaborator) commented Jun 26, 2023

What does this PR do?

This PR finishes the work done in a previous PR and completely cleans up _keys_to_ignore_on_save, _keys_to_ignore_on_load_missing and _keys_to_ignore_on_load_unexpected. Those class attributes were used in three situations:

  1. Not saving the tied weights. This came from the (wrong) assumption that torch would take twice the space for tied weights (which it doesn't), and it also created bugs where untied weights were not saved (unless a hack was added, like for RoBERTa models). It is not necessary since PyTorch doesn't take more space for tied weights and safetensors properly removes them (via _tied_weights_keys).

  2. Ignoring non-saved non-persistent buffers. This can be done automatically in the code of modeling_utils, as non-persistent buffers are named buffers of the model that do not appear in the state dict, so they are easy to detect (see the sketch after this list).

  3. Ignoring known unexpected weights coming from another architecture (like the pooler). This isn't necessary anymore since we no longer issue a warning in this case.
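
As a rough illustration of point 2, here is a minimal sketch (toy module and illustrative names, not transformers code) of how non-persistent buffers can be detected by comparing named buffers against the state dict:

```python
import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    def __init__(self, max_positions=512):
        super().__init__()
        # Persistent buffer: included in the state dict when saving.
        self.register_buffer("persistent_ids", torch.arange(max_positions))
        # Non-persistent buffer: recreated at init time, never saved.
        self.register_buffer("position_ids", torch.arange(max_positions), persistent=False)


model = ToyEmbeddings()
state_dict_keys = set(model.state_dict().keys())
# Non-persistent buffers are exactly the named buffers missing from the state dict.
non_persistent = {name for name, _ in model.named_buffers() if name not in state_dict_keys}
print(non_persistent)  # {'position_ids'}
```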

HuggingFaceDocBuilderDev commented Jun 26, 2023

The documentation is not available anymore as the PR was closed or merged.


  # If this weight is going to tip up over the maximal size, we split.
- if last_block_size + weight_size > max_shard_size:
+ if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
sgugger (Collaborator, Author) commented:
This change is necessary to make sure we save something in the first shard. With the removal of position_ids from the tensors saved, some tests of checkpoint sharding with BERT started failing.
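
For context, a simplified sketch of the sharding loop (not the exact modeling_utils implementation) showing why the extra len(sharded_state_dicts[-1]) > 0 guard keeps the first shard from ending up empty when a single weight exceeds max_shard_size:

```python
def shard_checkpoint_sketch(state_dict, max_shard_size):
    """Simplified sketch of the sharding logic discussed above."""
    sharded_state_dicts = [{}]
    last_block_size = 0
    for key, weight in state_dict.items():
        weight_size = weight.numel() * weight.element_size()
        # If this weight is going to tip up over the maximal size, we split,
        # but only if the current shard already holds something; otherwise a
        # weight bigger than max_shard_size would leave an empty shard behind.
        if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
            sharded_state_dicts.append({})
            last_block_size = 0
        sharded_state_dicts[-1][key] = weight
        last_block_size += weight_size
    return sharded_state_dicts
```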

LysandreJik (Member) commented:
I think it's worth adding a comment mentioning why we're doing this!

model.tie_weights()
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
    id_tensor = id_tensor_storage(tensor) if tensor.device != torch.device("meta") else id(tensor)
sgugger (Collaborator, Author) commented:
Accelerate detects tied weights by comparing tensor IDs, but that doesn't work for all models (deformable_detr for instance). So we use the same storage-based test as elsewhere, except when the tensor is on the meta device (where that test fails), in which case we default to id.
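
A sketch of the idea, with a simplified storage key standing in for transformers' id_tensor_storage helper (the exact tuple that helper returns is not reproduced here):

```python
import collections

import torch


def find_shared_tensor_groups(model):
    """Group state-dict keys that point to the same underlying storage (sketch)."""
    ptrs = collections.defaultdict(list)
    for name, tensor in model.state_dict().items():
        if tensor.device == torch.device("meta"):
            # Meta tensors have no real storage, so storage-based detection fails;
            # fall back to the Python object id, as in the diff above.
            key = id(tensor)
        else:
            # Simplified stand-in for id_tensor_storage: device + storage pointer.
            key = (tensor.device, tensor.untyped_storage().data_ptr())
        ptrs[key].append(name)
    return [names for names in ptrs.values() if len(names) > 1]
```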

self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))

# XXX: this might be a candidate for common tests if we have many of those
def test_lm_head_ignore_keys(self):
sgugger (Collaborator, Author) commented:
This was testing the hack added to remove the weights from _keys_to_ignore_on_save when they were untied.

LysandreJik (Member) commented:
Nice to see it go

f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
)

def test_load_save_without_tied_weights(self):
sgugger (Collaborator, Author) commented:
This new test checks that when weights are untied, they are properly saved and we complain if they are missing from the checkpoint.
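
Roughly, such a test can look like the following sketch (illustrative names, not the exact test added in this PR), using output_loading_info to assert that nothing is reported missing:

```python
import tempfile


def check_load_save_without_tied_weights(model_class, config):
    # Untie the embeddings so every weight must be physically present in the checkpoint.
    config.tie_word_embeddings = False
    model = model_class(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
        # With untied weights, nothing should be dropped at save time or missing at load time.
        assert infos["missing_keys"] == []
```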


- def test_tied_model_weights_key_ignore(self):
+ def test_model_weights_reload_no_missing_tied_weights(self):
      config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
sgugger (Collaborator, Author) commented:
The previous test was mostly focused on checking that tied weights were in the _keys_to_ignore_on_load_missing class variable, but we don't put them there anymore, so the test is adapted accordingly.

LysandreJik (Member) left a comment:
Looks good to me! Thanks for this change!



amyeroberts (Contributor) left a comment:
Really nice tidy up - thanks for working on this and updating!

  self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
- self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.register_buffer(
+     "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True
+ )
amyeroberts (Contributor) commented:
Just double-checking that this should be persistent=True. Assuming yes, given the other buffers here, but most other models seem to have it as False.

sgugger (Collaborator, Author) commented:
CLAP has the persistent=False->persistent=True substitution in its Copied from statement, that's why. I didn't want to break it accidentally, so I didn't touch that statement.


  # Initialize weights and apply final processing
  self.post_init()
- self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
amyeroberts (Contributor) commented:
Does it still work with the removal of this buffer?

sgugger (Collaborator, Author) commented:
Ah, it was a duplicate (not shown by the diff). You can scroll to line 452 below to see it defined again with persistent=False.

manav-glean commented:
@sgugger with this change, a few of our trained models that used the older format can no longer load properly unless we set strict=False, because they contain an embeddings.position_ids key that no longer exists. I wonder if there is a way to land this change so that it stays backwards compatible with older model files as well. I see a few different issues have popped up as a result of this change, and a lot of them just required loading and re-saving the model files, but that is sometimes difficult to do at scale.
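
One possible workaround for checkpoints produced before this change, sketched below (helper name and key suffix are illustrative): strip the stale buffer keys such as embeddings.position_ids before loading strictly, instead of falling back to strict=False.

```python
import torch


def load_legacy_state_dict(model, checkpoint_path):
    """Load an older checkpoint after dropping buffer keys that no longer exist (sketch)."""
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Remove keys such as "embeddings.position_ids" that newer code no longer expects.
    cleaned = {k: v for k, v in state_dict.items() if not k.endswith("position_ids")}
    return model.load_state_dict(cleaned)
```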
