@@ -518,12 +518,15 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
518518 p_mapping = param_to_save_data
519519 for name , param in self .name2param .items ():
520520 if param is not None :
521- origin_shape = self .params_info ["name2shape" ][name ]
522521 if is_ddp_ignored (param ):
523522 # deal with ddp ignored parameters
524523 destination [prefix + name ] = param if keep_vars else param .detach ()
525524 else :
526- destination [prefix + name ] = p_mapping [param ][: origin_shape [0 ], ...]
525+ if self .params_info is not None :
526+ origin_shape = self .params_info ["name2shape" ][name ]
527+ destination [prefix + name ] = p_mapping [param ][: origin_shape [0 ], ...]
528+ else :
529+ destination [prefix + name ] = p_mapping [param ]
527530 del p_mapping
528531 del param_to_save_data
529532
@@ -891,8 +894,10 @@ def state_dict_shard(
891894 gathered_param_buffer .update (self ._get_chunk_to_save_data (chunk , only_rank_0 ))
892895 gathered_param = gathered_param_buffer .pop (param_to_save )
893896
894- origin_shape = self .params_info ["name2shape" ][name ]
895- gathered_param = gathered_param [: origin_shape [0 ], ...]
897+ if self .params_info is not None :
898+ origin_shape = self .params_info ["name2shape" ][name ]
899+ gathered_param = gathered_param [: origin_shape [0 ], ...]
900+
896901 block , block_size = sharder .append_param (prefix + name , gathered_param )
897902 if block is not None :
898903 yield block , block_size
0 commit comments