File tree Expand file tree Collapse file tree 2 files changed +19
-1
lines changed Expand file tree Collapse file tree 2 files changed +19
-1
lines changed Original file line number Diff line number Diff line change @@ -1059,7 +1059,18 @@ def train(
10591059 # We load the model state dict on the CPU to avoid an OOM error.
10601060 state_dict = torch .load (os .path .join (resume_from_checkpoint , WEIGHTS_NAME ), map_location = "cpu" )
10611061 # If the model is on the GPU, it still works!
1062- self .model .load_state_dict (state_dict )
1062+ load_result = self .model .load_state_dict (state_dict , strict = False )
1063+ if len (load_result .missing_keys ) != 0 :
1064+ if load_result .missing_keys == self .model ._keys_to_ignore_on_save :
1065+ self .model .tie_weights ()
1066+ else :
1067+ logger .warn (
1068+ f"There were missing keys in the checkpoint model loaded: { load_result .missing_keys } ."
1069+ )
1070+ if len (load_result .unexpected_keys ) != 0 :
1071+ logger .warn (
1072+ f"There were unexpected keys in the checkpoint model loaded: { load_result .unexpected_keys } ."
1073+ )
10631074
10641075 # If model was re-initialized, put it on the right device and update self.model_wrapped
10651076 if model_reloaded :
Original file line number Diff line number Diff line change @@ -177,6 +177,13 @@ def test_save_load__keys_to_ignore_on_save(self):
177177 for k in _keys_to_ignore_on_save :
178178 self .assertNotIn (k , state_dict_saved )
179179
180+ # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
181+ load_result = model .load_state_dict (state_dict_saved , strict = False )
182+ self .assertTrue (
183+ len (load_result .missing_keys ) == 0 or load_result .missing_keys == model ._keys_to_ignore_on_save
184+ )
185+ self .assertTrue (len (load_result .unexpected_keys ) == 0 )
186+
180187 def _mock_init_weights (self , module ):
181188 if hasattr (module , "weight" ) and module .weight is not None :
182189 module .weight .data .fill_ (3 )
You can’t perform that action at this time.
0 commit comments