De-serializing a custom model in transformers v4.37.0 onwards doesn't work; weights aren't loaded from the saved checkpoint #29321

@ashishu007

Description

System Info

  • transformers version: 4.37.0
  • Platform: Linux-5.10.192-183.736.amzn2.x86_64-x86_64-with-glibc2.17
  • Python version: 3.8.11
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): not installed (NA)
  • Tensorflow version (GPU?): 2.12.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: False
  • Using distributed or parallel set-up in script?: False

Who can help?

@ArthurZucker @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I have a custom model that worked fine with transformers==4.33.1 but fails with transformers==4.37.0: de-serializing the model from a saved checkpoint no longer loads the weights. Instead, the model is randomly initialized and a long warning message is printed. This appears to be a backwards-incompatible change introduced between minor versions.

A compact example of the custom model that raises the same warning is given below.

Q: What can I do to make my code work with recent transformers versions?

Custom model class

import transformers
import tensorflow as tf
import numpy as np
from typing import List, Union, Tuple, Optional


class CustomBertClassifier(
    transformers.TFBertPreTrainedModel,
    transformers.modeling_tf_utils.TFSequenceClassificationLoss
):
    def __init__(self, config: transformers.BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # BERT backbone; note the custom layer name "bert_lm" rather than the usual "bert"
        self.bert = transformers.TFBertMainLayer(config, name="bert_lm")
        # Extra dense projection applied to the pooled output
        self.fc_layer = tf.keras.layers.Dense(units=768, name="fully_connected_layer")
        # Two-class classification head
        self.classifier = tf.keras.layers.Dense(
            units=2,
            name="classifier_head",
        )
        self.config = config

    def call(
        self,
        input_ids: Optional[transformers.modeling_tf_utils.TFModelInputType] = None,
        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
        training: Optional[bool] = False,
    ) -> Union[
        transformers.modeling_tf_outputs.TFSequenceClassifierOutput,
        Tuple[tf.Tensor]
    ]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # outputs[1] is the pooled [CLS] representation
        pooled_output = outputs[1]
        fc_output = self.fc_layer(inputs=pooled_output)
        logits = self.classifier(inputs=fc_output)
        # hf_compute_loss is provided by the TFSequenceClassificationLoss mixin
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return transformers.modeling_tf_outputs.TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
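
For reference, a minimal round trip that triggers the warning might look like the following (a sketch, since the original save/load script isn't shown here: the default BertConfig is an assumption, and the custom_model directory name matches the checkpoint path in the warning below):

config = transformers.BertConfig()
model = CustomBertClassifier(config)
model(model.dummy_inputs)  # run a forward pass so the weights are created
model.save_pretrained("custom_model")

# Under transformers==4.33.1 reloading restores the saved weights;
# under transformers==4.37.0 it emits the long warning shown below instead.
reloaded = CustomBertClassifier.from_pretrained("custom_model")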

Warning Message

The following warning message appears while de-serializing the model:

Some layers from the model checkpoint at custom_model were not used when initializing CustomBertClassifier: ['bert_lm/encoder/layer_._7/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._2/output/dense/bias:0', 'bert_lm/encoder/layer_._8/attention/self/key/bias:0', 'bert_lm/encoder/layer_._9/attention/self/query/bias:0', 'bert_lm/encoder/layer_._9/output/dense/kernel:0', 'bert_lm/encoder/layer_._2/attention/self/key/bias:0', 'bert_lm/encoder/layer_._2/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._6/attention/self/value/bias:0', 'bert_lm/encoder/layer_._10/attention/self/value/bias:0', 'bert_lm/encoder/layer_._5/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._5/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._3/intermediate/dense/bias:0', 'classifier_head/bias:0', 'bert_lm/encoder/layer_._9/output/dense/bias:0', 'bert_lm/encoder/layer_._9/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/self/query/bias:0', 'bert_lm/encoder/layer_._8/output/dense/bias:0', 'bert_lm/encoder/layer_._2/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/value/bias:0', 'bert_lm/encoder/layer_._2/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._8/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._9/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._9/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._11/attention/self/query/bias:0', 'bert_lm/encoder/layer_._5/output/dense/kernel:0', 'bert_lm/encoder/layer_._3/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._11/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/attention/self/key/bias:0', 'bert_lm/embeddings/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._10/attention/self/query/bias:0', 'bert_lm/encoder/layer_._4/output/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._1/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/output/dense/bias:0', 'bert_lm/encoder/layer_._1/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._4/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._7/output/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._6/intermediate/dense/kernel:0', 'bert_lm/embeddings/LayerNorm/beta:0', 'bert_lm/encoder/layer_._0/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._10/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._2/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._5/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._8/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._11/intermediate/dense/kernel:0', 'fully_connected_layer/kernel:0', 'bert_lm/embeddings/token_type_embeddings/embeddings:0', 'bert_lm/encoder/layer_._10/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._0/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._3/attention/self/query/bias:0', 'bert_lm/encoder/layer_._9/attention/self/key/bias:0', 'bert_lm/encoder/layer_._4/attention/self/query/bias:0', 'bert_lm/encoder/layer_._0/attention/self/query/bias:0', 'bert_lm/encoder/layer_._10/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._9/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/attention/output/dense/bias:0', 
'bert_lm/encoder/layer_._10/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._11/attention/self/key/bias:0', 'bert_lm/encoder/layer_._2/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._8/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._1/attention/self/query/bias:0', 'bert_lm/encoder/layer_._5/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/key/bias:0', 'bert_lm/encoder/layer_._0/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/self/value/bias:0', 'bert_lm/encoder/layer_._0/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/self/value/bias:0', 'bert_lm/encoder/layer_._8/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._2/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/output/dense/kernel:0', 'bert_lm/encoder/layer_._9/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._0/output/dense/bias:0', 'bert_lm/encoder/layer_._8/output/dense/kernel:0', 'bert_lm/encoder/layer_._2/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._11/output/dense/bias:0', 'bert_lm/encoder/layer_._8/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._11/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._0/attention/self/value/bias:0', 'bert_lm/encoder/layer_._6/attention/output/dense/kernel:0', 'bert_lm/embeddings/position_embeddings/embeddings:0', 'bert_lm/encoder/layer_._4/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/output/dense/bias:0', 'bert_lm/encoder/layer_._0/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._10/output/dense/kernel:0', 'bert_lm/encoder/layer_._9/attention/self/value/bias:0', 'bert_lm/encoder/layer_._10/output/dense/bias:0', 'bert_lm/encoder/layer_._9/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._7/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._10/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._3/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._5/attention/output/dense/bias:0', 'bert_lm/pooler/dense/bias:0', 'bert_lm/encoder/layer_._0/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._2/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._3/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._7/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._1/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._0/output/dense/kernel:0', 'bert_lm/encoder/layer_._1/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._1/attention/self/value/bias:0', 'bert_lm/encoder/layer_._3/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._4/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._3/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._8/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._0/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._8/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._8/attention/self/value/bias:0', 'bert_lm/encoder/layer_._9/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._9/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._10/intermediate/dense/kernel:0', 
'bert_lm/encoder/layer_._4/attention/self/value/bias:0', 'bert_lm/encoder/layer_._10/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._1/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._11/output/LayerNorm/gamma:0', 'classifier_head/kernel:0', 'bert_lm/encoder/layer_._5/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._7/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._8/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._5/attention/self/query/bias:0', 'bert_lm/encoder/layer_._1/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/dense/kernel:0', 'bert_lm/pooler/dense/kernel:0', 'bert_lm/encoder/layer_._5/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._10/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._4/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._11/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._9/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._9/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._6/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._6/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._2/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._10/attention/self/key/bias:0', 'bert_lm/encoder/layer_._0/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._2/attention/self/query/kernel:0', 'bert_lm/encoder/layer_._4/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._1/output/dense/kernel:0', 'bert_lm/encoder/layer_._3/output/dense/bias:0', 'bert_lm/encoder/layer_._3/output/dense/kernel:0', 'bert_lm/encoder/layer_._4/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._8/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._10/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._2/attention/self/query/bias:0', 'bert_lm/encoder/layer_._5/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._8/attention/output/LayerNorm/gamma:0', 'bert_lm/embeddings/word_embeddings/weight:0', 'bert_lm/encoder/layer_._9/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/query/bias:0', 'bert_lm/encoder/layer_._0/attention/self/key/bias:0', 'bert_lm/encoder/layer_._0/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._0/intermediate/dense/kernel:0', 'bert_lm/encoder/layer_._7/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._2/output/dense/kernel:0', 'bert_lm/encoder/layer_._8/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._5/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._1/attention/self/key/bias:0', 'bert_lm/encoder/layer_._2/attention/self/value/bias:0', 'bert_lm/encoder/layer_._6/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._7/attention/self/key/bias:0', 'bert_lm/encoder/layer_._6/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/attention/self/key/bias:0', 'bert_lm/encoder/layer_._7/output/dense/bias:0', 'bert_lm/encoder/layer_._7/attention/self/value/bias:0', 'bert_lm/encoder/layer_._3/attention/self/key/bias:0', 'bert_lm/encoder/layer_._10/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._11/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._1/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._1/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._3/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._11/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._0/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._4/attention/output/dense/bias:0', 
'bert_lm/encoder/layer_._11/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._4/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._3/attention/self/key/kernel:0', 'bert_lm/encoder/layer_._6/attention/self/value/kernel:0', 'bert_lm/encoder/layer_._1/attention/output/dense/kernel:0', 'bert_lm/encoder/layer_._6/intermediate/dense/bias:0', 'bert_lm/encoder/layer_._4/attention/output/LayerNorm/beta:0', 'bert_lm/encoder/layer_._8/attention/self/query/bias:0', 'bert_lm/encoder/layer_._3/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._5/output/dense/bias:0', 'bert_lm/encoder/layer_._10/output/LayerNorm/beta:0', 'fully_connected_layer/bias:0', 'bert_lm/encoder/layer_._4/attention/output/LayerNorm/gamma:0', 'bert_lm/encoder/layer_._11/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._2/attention/output/dense/bias:0', 'bert_lm/encoder/layer_._5/attention/output/LayerNorm/beta:0']
- This IS expected if you are initializing CustomBertClassifier from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CustomBertClassifier from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of CustomBertClassifier were initialized from the model checkpoint at custom_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use CustomBertClassifier for predictions without further training.
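
One way to confirm that the weights really are re-initialized (and that the warning isn't spurious) is to compare a tensor across the round trip sketched above; something like:

import numpy as np

# Compare the first backbone variable before and after reloading
# (the indexing here is illustrative; any shared variable works).
before = model.bert.weights[0].numpy()
after = reloaded.bert.weights[0].numpy()
print("weights preserved:", np.allclose(before, after))
# Expected False under v4.37.0 per the behaviour described above.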

Expected behavior

Model de-serialization should restore the saved weights without emitting the long warning message above.

Instead of the long warning message, the following short message should appear, confirming that everything loaded correctly:

All model checkpoint layers were used when initializing CustomBertClassifier.

All the layers of CustomBertClassifier were initialized from the model checkpoint at custom_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use CustomBertClassifier for predictions without further training.
