diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index acc0df630991..3eca3b0a32f8 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -275,12 +275,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
     config_class = InstructBlipConfig
     base_model_prefix = "blip"
     supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_missing = [
-        r"position_ids",
-        r"language_model.encoder.embed_tokens.weight",
-        r"language_model.decoder.embed_tokens.weight",
-        r"language_model.lm_head.weight",
-    ]
     _no_split_modules = ["InstructBlipAttention", "InstructBlipQFormerMultiHeadAttention"]
     _keep_in_fp32_modules = []
 
@@ -1011,7 +1005,9 @@ def __init__(self, config):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.config = config
 
diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py
index 145238c6bfd6..2af7c4c6178d 100644
--- a/tests/models/timm_backbone/test_modeling_timm_backbone.py
+++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py
@@ -176,6 +176,14 @@ def test_tie_model_weights(self):
     def test_tied_model_weights_key_ignore(self):
         pass
 
+    @unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
+    def test_load_save_without_tied_weights(self):
+        pass
+
+    @unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
+    def test_model_weights_reload_no_missing_tied_weights(self):
+        pass
+
     @unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.")
     def test_channels(self):
         pass
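
Why the deleted _keys_to_ignore_on_load_missing list is safe to drop: registering a buffer with persistent=False keeps it out of state_dict(), so the value is rebuilt in __init__ instead of being saved to or loaded from checkpoints, and loading can never report it as a missing key. Below is a minimal, self-contained sketch of that stock PyTorch behavior; the Embeddings class and the scale buffer are illustrative stand-ins, not code from this patch.

import torch
from torch import nn


class Embeddings(nn.Module):
    """Toy module mirroring the position_ids pattern in the patch."""

    def __init__(self, max_position_embeddings: int = 16):
        super().__init__()
        # Non-persistent buffer: recreated here on every construction,
        # never written to (or expected from) a checkpoint.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).expand((1, -1)),
            persistent=False,
        )
        # Persistent buffer (the default): round-trips through state_dict.
        self.register_buffer("scale", torch.ones(1))


module = Embeddings()
keys = set(module.state_dict().keys())
assert "scale" in keys
assert "position_ids" not in keys  # absent, so never a "missing key" on load

# Loading a checkpoint that lacks position_ids raises no error either.
module.load_state_dict({"scale": torch.ones(1)})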