32 changes: 24 additions & 8 deletions src/transformers/modeling_utils.py
@@ -320,8 +320,9 @@ def shard_checkpoint(

weight_size = weight.numel() * dtype_byte_size(weight.dtype)

# If this weight is going to tip up over the maximal size, we split.
if last_block_size + weight_size > max_shard_size:
# If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
# weight in the current shard.
if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
Review comment from the PR author (Collaborator): This change is necessary to make sure we save something in the first shard. With the removal of position_ids from the tensors saved, some tests of checkpoint sharding with BERT started failing.

Review comment from a reviewer (Member): I think it's worth adding a comment mentioning why we're doing this!

sharded_state_dicts.append({})
last_block_size = 0
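
A minimal sketch of the splitting rule above (not the actual `shard_checkpoint` code): a new shard is only opened once the current shard holds at least one weight, so a single weight larger than `max_shard_size` still lands in the first shard instead of leaving it empty. The helper name and the size-only inputs are illustrative.

```python
# Illustrative sketch of the sharding rule, with weights reduced to their byte sizes.
def split_into_shards(weight_sizes, max_shard_size):
    shards = [[]]  # each shard collects (weight_index, size) pairs
    last_block_size = 0
    for idx, size in enumerate(weight_sizes):
        # Only open a new shard if the current one already holds something,
        # otherwise an oversized first weight would leave an empty shard behind.
        if last_block_size + size > max_shard_size and len(shards[-1]) > 0:
            shards.append([])
            last_block_size = 0
        shards[-1].append((idx, size))
        last_block_size += size
    return shards

# With max_shard_size=10, the oversized first weight (12) still goes into shard 0.
print(split_into_shards([12, 4, 5, 6], max_shard_size=10))
# [[(0, 12)], [(1, 4), (2, 5)], [(3, 6)]]
```
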

@@ -3044,15 +3045,30 @@ def _fix_key(key):
expected_keys = [".".join([prefix, s]) for s in expected_keys]

missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)
# Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
# buffers
model_buffers = {n for n, _ in model.named_buffers()}
if remove_prefix_from_model:
model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
elif add_prefix_to_model:
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = list(unexpected_keys - model_buffers)

if is_accelerate_available():
model.tie_weights()
tied_params = find_tied_parameters(model)
else:
tied_params = []
model.tie_weights()
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor) if tensor.device != torch.device("meta") else id(tensor)
Review comment from the PR author (Collaborator): Accelerate detects tied weights by their IDs, but that doesn't work for all models (deformable_detr, for instance). So we use the same storage-based check as elsewhere, except when the tensor is on the meta device (where that check fails), in which case we fall back to id().

ptrs[id_tensor].append(name)

# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]

for group in tied_params:
if remove_prefix_from_model:
group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
elif add_prefix_to_model:
group = [".".join([prefix, key]) for key in group]
missing_in_group = [k for k in missing_keys if k in group]
if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
missing_keys = [k for k in missing_keys if k not in missing_in_group]
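
A hedged sketch of the grouping logic above, written with plain PyTorch rather than the library helpers (`id_tensor_storage`, `find_tied_parameters`): tensors that share the same underlying storage are treated as tied, and `id()` is the fallback for meta-device tensors whose storage pointer is meaningless. It assumes a recent PyTorch with `Tensor.untyped_storage()`; the function and variable names are illustrative.

```python
import collections

import torch
import torch.nn as nn


def group_shared_tensors(state_dict):
    """Group names of tensors that share underlying storage (a stand-in for the
    storage-pointer check used above; meta tensors fall back to id())."""
    ptrs = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        if tensor.device == torch.device("meta"):
            key = id(tensor)  # storage pointers are meaningless on the meta device
        else:
            key = (tensor.device, tensor.untyped_storage().data_ptr())
        ptrs[key].append(name)
    return [names for names in ptrs.values() if len(names) > 1]


# Tiny example with tied input/output embeddings.
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tie the two weights to the same Parameter
state = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}
print(group_shared_tensors(state))  # [['embed.weight', 'lm_head.weight']]
```
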
20 changes: 3 additions & 17 deletions src/transformers/models/albert/modeling_albert.py
@@ -208,7 +208,9 @@ def __init__(self, config: AlbertConfig):
self.dropout = nn.Dropout(config.hidden_dropout_prob)

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
@@ -507,7 +509,6 @@ class AlbertPreTrainedModel(PreTrainedModel):
config_class = AlbertConfig
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
"""Initialize the weights."""
@@ -760,11 +761,6 @@ def forward(
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]

def __init__(self, config: AlbertConfig):
super().__init__(config)
@@ -912,13 +908,7 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_keys_to_ignore_on_load_missing = [
"predictions.decoder.weight",
"predictions.decoder.bias",
"embeddings.position_ids",
]

def __init__(self, config):
super().__init__(config)
@@ -1133,8 +1123,6 @@ def forward(
ALBERT_START_DOCSTRING,
)
class AlbertForTokenClassification(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]

def __init__(self, config: AlbertConfig):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1218,8 +1206,6 @@ def forward(
ALBERT_START_DOCSTRING,
)
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]

def __init__(self, config: AlbertConfig):
super().__init__(config)
self.num_labels = config.num_labels
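
The same `persistent=False` change recurs throughout the model files in this PR. As a standalone illustration (not library code) of what it does: a non-persistent buffer still exists on the module and moves with it, but it is excluded from `state_dict()`, so it can no longer show up as a missing key and no longer needs an entry in `_keys_to_ignore_on_load_missing`. The tiny module below is hypothetical.

```python
import torch
import torch.nn as nn


class TinyEmbeddings(nn.Module):
    def __init__(self, max_positions=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(16, 4)
        # Non-persistent: available as self.position_ids, but excluded from state_dict().
        self.register_buffer(
            "position_ids", torch.arange(max_positions).expand((1, -1)), persistent=False
        )


module = TinyEmbeddings()
print("position_ids" in dict(module.named_buffers()))  # True: the buffer exists
print("position_ids" in module.state_dict())            # False: it is not serialized
# Loading a checkpoint without position_ids therefore reports no missing key.
module.load_state_dict({"word_embeddings.weight": torch.zeros(16, 4)})
```
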
5 changes: 3 additions & 2 deletions src/transformers/models/align/modeling_align.py
@@ -687,7 +687,9 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
@@ -1176,7 +1178,6 @@ class AlignPreTrainedModel(PreTrainedModel):
config_class = AlignConfig
base_model_prefix = "align"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
"""Initialize the weights"""
7 changes: 4 additions & 3 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -216,7 +216,9 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
@@ -1016,7 +1018,7 @@ def __init__(self, config: AltCLIPVisionConfig):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
@@ -1038,7 +1040,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
config_class = AltCLIPConfig
base_model_prefix = "altclip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
"""Initialize the weights"""
13 changes: 2 additions & 11 deletions src/transformers/models/bart/modeling_bart.py
@@ -506,7 +506,7 @@ class BartPretrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
_keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"

@@ -1170,7 +1170,6 @@ def custom_forward(*inputs):
BART_START_DOCSTRING,
)
class BartModel(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config: BartConfig):
@@ -1300,12 +1299,7 @@ def forward(
class BartForConditionalGeneration(BartPretrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = [
"final_logits_bias",
"lm_head.weight",
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]

def __init__(self, config: BartConfig):
super().__init__(config)
@@ -1478,7 +1472,6 @@ def _reorder_cache(past_key_values, beam_idx):
BART_START_DOCSTRING,
)
class BartForSequenceClassification(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config: BartConfig, **kwargs):
@@ -1609,7 +1602,6 @@ def forward(
BART_START_DOCSTRING,
)
class BartForQuestionAnswering(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

def __init__(self, config):
@@ -1748,7 +1740,6 @@ def forward(self, *args, **kwargs):
BART_START_DOCSTRING,
)
class BartForCausalLM(BartPretrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
2 changes: 1 addition & 1 deletion src/transformers/models/beit/modeling_beit.py
@@ -459,7 +459,7 @@ def __init__(self, config: BeitConfig, window_size: tuple) -> None:
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1

self.register_buffer("relative_position_index", relative_position_index)
self.register_buffer("relative_position_index", relative_position_index, persistent=False)

def forward(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
14 changes: 3 additions & 11 deletions src/transformers/models/bert/modeling_bert.py
@@ -192,7 +192,9 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
@@ -743,7 +745,6 @@ class BertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
"""Initialize the weights"""
@@ -1053,7 +1054,6 @@ def forward(
BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
@@ -1160,8 +1160,6 @@ def forward(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
@@ -1301,8 +1299,6 @@ def _reorder_cache(self, past_key_values, beam_idx):

@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
@@ -1715,8 +1711,6 @@ def forward(
BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1800,8 +1794,6 @@ def forward(
BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
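
Conversely, an older checkpoint that still contains a now non-persistent `position_ids` buffer surfaces it as an unexpected key; the `modeling_utils.py` change above filters out any such key that matches a model buffer. A hedged, pure-PyTorch sketch of that filtering (the class and key names are illustrative):

```python
import torch
import torch.nn as nn


class TinyBertLikeEmbeddings(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(16, 4)
        self.register_buffer("position_ids", torch.arange(8).expand((1, -1)), persistent=False)


model = TinyBertLikeEmbeddings()
# An "old" checkpoint saved back when position_ids was still a persistent buffer.
old_checkpoint = {
    "word_embeddings.weight": torch.zeros(16, 4),
    "position_ids": torch.arange(8).expand((1, -1)),
}
result = model.load_state_dict(old_checkpoint, strict=False)
print(result.unexpected_keys)  # ['position_ids']

# Mirror of the filtering added in modeling_utils.py: keys that correspond to
# model buffers are not worth warning about, since the model provides them itself.
model_buffers = {name for name, _ in model.named_buffers()}
print([k for k in result.unexpected_keys if k not in model_buffers])  # []
```
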
src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -556,7 +556,9 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)

# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)

def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
if input_ids is not None:
@@ -588,7 +590,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
config_class = BertGenerationConfig
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
"""Initialize the weights"""
@@ -860,7 +861,6 @@ def _tie_weights(self):
BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

def __init__(self, config):
13 changes: 3 additions & 10 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -257,7 +257,9 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
self.register_buffer(
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
@@ -1765,7 +1767,6 @@ class BigBirdPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_big_bird
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]

def _init_weights(self, module):
"""Initialize the weights"""
@@ -2261,7 +2262,6 @@ def _pad_to_block_size(


class BigBirdForPreTraining(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

def __init__(self, config):
@@ -2368,7 +2368,6 @@ def forward(

@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

def __init__(self, config):
@@ -2513,12 +2512,6 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
"""BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
)
class BigBirdForCausalLM(BigBirdPreTrainedModel):
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"predictions.decoder.bias",
"cls.predictions.decoder.weight",
"cls.predictions.decoder.bias",
]
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

def __init__(self, config):