Environment info
- transformers version: 4.6.0.dev0 (also happens with pip 4.5.1)
- Platform: Linux-4.19.112+-x86_64-with-Ubuntu-18.04-bionic (Google Colab)
- Python version: 3.7.10
- PyTorch version (GPU?): 1.8.1+cu101 (True)
- Tensorflow version (GPU?): Not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help
- gpt2: @patrickvonplaten, @LysandreJik
- trainer: @sgugger
Information
Resuming training from a Trainer checkpoint for GPTNeoForCausalLM causes the following runtime error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-14-3b03205cdcc2> in <module>()
2 ### %%%%%%%%%%%%%%%%%%%%%%%% TRAINING %%%%%%%%%%%%%%%%%%%%%%%%% ###
3 ### %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ###
----> 4 trainer.train(checkpoint)
1 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1222 if len(error_msgs) > 0:
1223 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224 self.__class__.__name__, "\n\t".join(error_msgs)))
1225 return _IncompatibleKeys(missing_keys, unexpected_keys)
1226
RuntimeError: Error(s) in loading state_dict for GPTNeoForCausalLM:
Missing key(s) in state_dict: "lm_head.weight".
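For what it's worth, the key appears to be absent from the saved weights themselves, presumably because GPT-Neo ties lm_head.weight to the input embeddings and skips it on save. A quick check (sketch; the checkpoint path is a placeholder, pytorch_model.bin is the standard Trainer weights filename):

import torch

# Sketch: peek into the checkpoint Trainer wrote (path is a placeholder).
state_dict = torch.load("./results/checkpoint-10/pytorch_model.bin", map_location="cpu")
print("lm_head.weight" in state_dict)          # False -- the key the strict load reports as missing
print("transformer.wte.weight" in state_dict)  # True -- the input embeddings are saved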
This happens with the 125M model; I haven't tested with 1.3B and 2.7B. Loading the model manually using .from_pretrained() and commenting out the following lines in transformers/trainer.py
else:
    # We load the model state dict on the CPU to avoid an OOM error.
    state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
    # If the model is on the GPU, it still works!
    self.model.load_state_dict(state_dict)
allows me to resume training.
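A sketch of that workaround, assuming the lines above are commented out so Trainer skips its own strict load (the checkpoint path, training_args, and train_dataset are placeholders for the objects from the original setup):

from transformers import GPTNeoForCausalLM, Trainer

checkpoint = "./results/checkpoint-10"  # placeholder: the saved checkpoint directory

# Rebuild the model from the checkpoint with .from_pretrained(), which tolerates
# the absent (tied) lm_head.weight, instead of letting Trainer call
# load_state_dict() strictly as in the snippet above.
model = GPTNeoForCausalLM.from_pretrained(checkpoint)

trainer = Trainer(
    model=model,
    args=training_args,           # placeholder: the original TrainingArguments
    train_dataset=train_dataset,  # placeholder: the original dataset
)
trainer.train(checkpoint)  # optimizer/scheduler state is still restored from the checkpoint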
To reproduce
Steps to reproduce the behavior:
- Initialize training via Trainer for GPTNeoForCausalLM and save a checkpoint
- Reset the environment and try to resume training from that checkpoint (a minimal sketch of these steps follows below)
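A minimal sketch of those steps (the toy dataset, output paths, and hyperparameters are placeholders, not from the original run):

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, Trainer, TrainingArguments

class TinyDataset(torch.utils.data.Dataset):
    # Placeholder dataset: a few identical short examples, just enough to step the Trainer.
    def __init__(self, tokenizer):
        enc = tokenizer("hello world " * 16, return_tensors="pt")
        self.input_ids = enc["input_ids"][0]
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {"input_ids": self.input_ids, "labels": self.input_ids}

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

training_args = TrainingArguments(
    output_dir="./results",  # checkpoints land here as checkpoint-<step>
    save_steps=10,
    per_device_train_batch_size=1,
)
trainer = Trainer(model=model, args=training_args, train_dataset=TinyDataset(tokenizer))
trainer.train()  # step 1: train and let Trainer save checkpoints

# Step 2: reset the environment (e.g. restart the Colab runtime), rebuild the
# model/Trainer exactly as above, then resume -- this raises the RuntimeError:
trainer.train(resume_from_checkpoint="./results/checkpoint-10")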
Expected behavior
Training should resume from the checkpoint without errors.