@@ -1935,7 +1935,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         else:
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-            load_result = model.load_state_dict(state_dict, strict=False)
+            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+            # which takes *args instead of **kwargs
+            load_result = model.load_state_dict(state_dict, False)
             # release memory
             del state_dict
             self._issue_warnings_after_load(load_result)
@@ -1989,7 +1991,9 @@ def _load_best_model(self):
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(best_model_path, map_location="cpu")
             # If the model is on the GPU, it still works!
-            load_result = model.load_state_dict(state_dict, strict=False)
+            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+            # which takes *args instead of **kwargs
+            load_result = model.load_state_dict(state_dict, False)
             if not is_sagemaker_mp_enabled():
                 self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
0 commit comments