
GPTNeoForCausalLM: resuming Trainer from checkpoint causes Missing key(s) in state_dict: "lm_head.weight" #11666

@xusky69

Description


Environment info

  • transformers version: 4.6.0.dev0 (also happens with pip 4.5.1)
  • Platform: Linux-4.19.112+-x86_64-with-Ubuntu-18.04-bionic (Google Colab)
  • Python version: 3.7.10
  • PyTorch version (GPU?): 1.8.1+cu101 (True)
  • Tensorflow version (GPU?): Not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help

Information

Resuming training from a Trainer checkpoint for GPTNeoForCausalLM causes the following runtime error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-14-3b03205cdcc2> in <module>()
      2 ### %%%%%%%%%%%%%%%%%%%%%%%% TRAINING %%%%%%%%%%%%%%%%%%%%%%%%% ###
      3 ### %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ###
----> 4 trainer.train(checkpoint)

1 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1222         if len(error_msgs) > 0:
   1223             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1225         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1226 

RuntimeError: Error(s) in loading state_dict for GPTNeoForCausalLM:
	Missing key(s) in state_dict: "lm_head.weight". 
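
For reference, the missing key can be confirmed by inspecting the checkpoint file that Trainer saved (the checkpoint path below is an example; substitute your own):

import torch

# Load the raw state dict written by Trainer and check for the lm_head key
state_dict = torch.load("output/checkpoint-500/pytorch_model.bin", map_location="cpu")
print("lm_head.weight" in state_dict)  # False for the failing checkpoint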

This happens with the 125M model; I haven't tested the 1.3B and 2.7B models. Loading the model manually with .from_pretrained() and commenting out the following lines in /transformers/trainer.py allows me to resume training:

            else:
                # We load the model state dict on the CPU to avoid an OOM error.
                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
                # If the model is on the GPU, it still works!
                self.model.load_state_dict(state_dict)
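
A rough sketch of the workaround (the checkpoint path, training arguments, and dataset are placeholders, not my exact values):

from transformers import GPTNeoForCausalLM, Trainer, TrainingArguments

checkpoint = "output/checkpoint-500"  # example checkpoint directory written by Trainer

# Load the weights via from_pretrained(), which does not require lm_head.weight
# to be present in the saved file, instead of letting Trainer call a strict
# load_state_dict() on the raw state dict
model = GPTNeoForCausalLM.from_pretrained(checkpoint)

args = TrainingArguments(output_dir="output")
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # same dataset as the original run
)

# With the load_state_dict() call in trainer.py commented out, this resumes training
trainer.train(checkpoint)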

To reproduce

Steps to reproduce the behavior:

  1. Initialize training via Trainer for GPTNeoForCausalLM and save a checkpoint
  2. Reset the environment and try to resume training from that checkpoint (a minimal sketch follows below)
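
A minimal sketch of these two steps, using the 125M checkpoint and a dummy dataset (model id, output paths, and hyperparameters are placeholders):

import torch
from torch.utils.data import Dataset
from transformers import GPTNeoForCausalLM, Trainer, TrainingArguments

class DummyLMDataset(Dataset):
    """Tiny random-token dataset, just to make Trainer run."""
    def __len__(self):
        return 64
    def __getitem__(self, idx):
        ids = torch.randint(0, 50257, (32,))
        return {"input_ids": ids, "labels": ids}

# Step 1: train briefly so Trainer writes a checkpoint under output/
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
args = TrainingArguments(output_dir="output", max_steps=20, save_steps=10)
trainer = Trainer(model=model, args=args, train_dataset=DummyLMDataset())
trainer.train()

# Step 2: restart the runtime, rebuild the Trainer the same way, then resume
trainer.train("output/checkpoint-10")  # raises the RuntimeError above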

Expected behavior

Training should resume correctly from the saved checkpoint.
