 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Set, Tuple
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
 from colossalai.utils import get_current_device
 
 from .general_checkpoint_io import GeneralCheckpointIO
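The colossalai.tensor.p_tensor helpers imported above (is_padded_tensor, to_padded_tensor, to_unpadded_tensor, init_as_ptensor) are what the rest of this diff uses to round a tensor up to a divisible length and to recover the original length again. A rough, self-contained sketch of that padding semantics in plain torch (illustrative helper names, not the colossalai implementation):

import torch

def pad_to_length(t: torch.Tensor, target_len: int, dim: int) -> torch.Tensor:
    # Append zeros along `dim` until the size reaches `target_len`.
    pad_len = target_len - t.shape[dim]
    if pad_len <= 0:
        return t
    pad = torch.zeros(*t.shape[:dim], pad_len, *t.shape[dim + 1:], dtype=t.dtype, device=t.device)
    return torch.cat((t, pad), dim=dim)

def unpad_to_length(t: torch.Tensor, origin_len: int, dim: int) -> torch.Tensor:
    # Drop the trailing padding along `dim`, restoring the original length.
    return t.narrow(dim, 0, origin_len).contiguous()

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
padded = pad_to_length(x, 4, dim=1)       # shape (2, 4), last column is zeros
restored = unpad_to_length(padded, 3, 1)  # shape (2, 3), equal to x
assert torch.equal(restored, x)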
@@ -77,39 +78,39 @@ def __init__(
         self.verbose = verbose
         self.coordinator = DistCoordinator()
 
-    @staticmethod
-    def _named_modules(
-        module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
-    ):
-        r"""Returns an iterator over all leaf modules in the network, yielding
-        both the name of the module as well as the module itself.
-
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-
-        Yields:
-            (str, Module): Tuple of name and module
-
-        Note:
-            Duplicate modules are returned only once. In the following
-            example, ``l`` will be returned only once.
-        """
-        if memo is None:
-            memo = set()
-
-        if module not in memo:
-            sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None]
-            if len(sub_modules) == 0:
-                if remove_duplicate:
-                    memo.add(module)
-                yield prefix, module
-            else:
-                for name, subm in sub_modules:
-                    submodule_prefix = prefix + ("." if prefix else "") + name
-                    yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate)
+    # @staticmethod
+    # def _named_modules(
+    #     module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+    # ):
+    #     r"""Returns an iterator over all leaf modules in the network, yielding
+    #     both the name of the module as well as the module itself.
+
+    #     Args:
+    #         memo: a memo to store the set of modules already added to the result
+    #         prefix: a prefix that will be added to the name of the module
+    #         remove_duplicate: whether to remove the duplicated module instances in the result
+    #             or not
+
+    #     Yields:
+    #         (str, Module): Tuple of name and module
+
+    #     Note:
+    #         Duplicate modules are returned only once. In the following
+    #         example, ``l`` will be returned only once.
+    #     """
+    #     if memo is None:
+    #         memo = set()
+
+    #     if module not in memo:
+    #         sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None]
+    #         if len(sub_modules) == 0:
+    #             if remove_duplicate:
+    #                 memo.add(module)
+    #             yield prefix, module
+    #         else:
+    #             for name, subm in sub_modules:
+    #                 submodule_prefix = prefix + ("." if prefix else "") + name
+    #                 yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate)
 
     @staticmethod
     def _model_sharder(
@@ -120,18 +121,29 @@ def _model_sharder(
         state_dict_sharder = StateDictSharder(size_per_shard)
 
         # Save parameters.
-        for module_name, module in HybridParallelCheckpointIO._named_modules(model):
-            state_dicts = module.state_dict()
-            for name, param in state_dicts.items():
-                if param is None:
-                    continue
-                # Gather tensor pieces when using tensor parallel.
-                param_ = gather_distributed_param(param, keep_vars=False)
-                if module_name != "":
-                    module_name = module_name + "."
-                block, block_size = state_dict_sharder.append_param(module_name + name, param_)
-                if block is not None:
-                    yield block, block_size
+        # for module_name, module in HybridParallelCheckpointIO._named_modules(model):
+        #     state_dicts = module.state_dict()
+        #     for name, param in state_dicts.items():
+        #         if param is None:
+        #             continue
+        #         # Gather tensor pieces when using tensor parallel.
+        #         param_ = gather_distributed_param(param, keep_vars=False)
+        #         if module_name != "":
+        #             module_name = module_name + "."
+        #         block, block_size = state_dict_sharder.append_param(module_name + name, param_)
+        #         if block is not None:
+        #             yield block, block_size
+        for name, param in model.named_parameters():
+            if param is None:
+                continue
+            # Strip any divisibility padding before saving so the checkpoint keeps the original shape.
+            if is_padded_tensor(param):
+                param = to_unpadded_tensor(param)
+            # Gather tensor pieces when using tensor parallel.
+            param_ = gather_distributed_param(param, keep_vars=False)
+            block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+            if block is not None:
+                yield block, block_size
 
         # Save buffers.
         for name, buf in model.named_buffers():
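The rewritten parameter loop above unpads each parameter, gathers its tensor-parallel pieces, and hands the result to a StateDictSharder, yielding a finished block whenever one fills up. A minimal sketch of that append-and-yield pattern (a toy sharder, not colossalai's StateDictSharder; it assumes size_per_shard is expressed in MB):

import torch

class TinyStateDictSharder:
    # Accumulate tensors into a shard and hand the full shard back once the
    # next tensor would push it past size_per_shard megabytes.
    def __init__(self, size_per_shard: int):
        self.max_bytes = size_per_shard * 1024 * 1024
        self.current_block = {}
        self.current_size = 0

    def append_param(self, name: str, tensor: torch.Tensor):
        block, block_size = None, 0
        nbytes = tensor.numel() * tensor.element_size()
        if self.current_block and self.current_size + nbytes > self.max_bytes:
            block, block_size = self.current_block, self.current_size
            self.current_block, self.current_size = {}, 0
        self.current_block[name] = tensor
        self.current_size += nbytes
        return block, block_size

sharder = TinyStateDictSharder(size_per_shard=1)
for i in range(8):
    block, _ = sharder.append_param(f"layer{i}.weight", torch.randn(256, 256))
    if block is not None:
        print(f"flushing shard with {len(block)} tensors")  # triggers once in this toy run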
@@ -906,7 +918,13 @@ def gather_from_sharded_optimizer_state(
                     dist.all_gather(gather_tensor, v, group=tp_group)
                     v = torch.cat(gather_tensor, dim=partition_dim)
 
-                state_[k] = v.detach().clone()[: original_shape[0], ...].to(device)
+                # The gathered state may still carry divisibility padding; find the padded dim and strip it.
+                padding_dim = search_padding_dim(v.shape, original_shape)
+                if padding_dim is not None:
+                    v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
+                    v = to_unpadded_tensor(v)
+
+                state_[k] = v.detach().clone().to(device)
 
         return state_
 
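After the tensor-parallel all_gather, the optimizer state may still be longer than the parameter's original shape along the padded dimension, so the new code looks up that dimension and strips the padding rather than hard-coding a slice on dim 0. A small sketch of the detection step, assuming search_padding_dim simply returns the first dimension whose sizes differ (which matches how it is used here):

import torch
from typing import Optional

def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
    # Return the first dimension whose size differs, i.e. the padded dimension,
    # or None when the shapes already match.
    for dim, (g, o) in enumerate(zip(global_shape, original_shape)):
        if g != o:
            return dim
    return None

gathered = torch.randn(10, 4)            # e.g. a vocab dim padded from 9 to 10 rows
original_shape = torch.Size((9, 4))
dim = search_padding_dim(gathered.shape, original_shape)
if dim is not None:
    gathered = gathered.narrow(dim, 0, original_shape[dim]).contiguous()
assert gathered.shape == original_shape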
@@ -949,15 +967,17 @@ def shard_from_complete_optimizer_state(
 
                 padding_dim = search_padding_dim(global_shape, original_shape)
                 if padding_dim is not None:
-                    padding_size = global_shape[padding_dim] - original_shape[padding_dim]
-                    padding_data = torch.zeros(
-                        *v.shape[:padding_dim],
-                        padding_size,
-                        *v.shape[padding_dim + 1 :],
-                        device=v.device,
-                        dtype=v.dtype,
-                    )
-                    v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
+                    # padding_size = global_shape[padding_dim] - original_shape[padding_dim]
+                    # padding_data = torch.zeros(
+                    #     *v.shape[:padding_dim],
+                    #     padding_size,
+                    #     *v.shape[padding_dim + 1 :],
+                    #     device=v.device,
+                    #     dtype=v.dtype,
+                    # )
+                    # v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
+                    # Re-apply the padding up to the global (tensor-parallel) shape before slicing out this rank's shard.
+                    v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)
 
                 if partition_dim is not None:
                     slice_size = current_shape[partition_dim]
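On the load path the direction is reversed: the full optimizer state is padded back up to the global (tensor-parallel) shape, then split along partition_dim so each rank keeps only its own slice, as the surrounding lines of this hunk do. A single-process sketch of that pad-then-slice flow, with the rank and world size hard-coded purely for illustration:

import torch

def pad_then_shard(v: torch.Tensor, padded_len: int, padding_dim: int,
                   partition_dim: int, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Pad `v` with zeros along padding_dim up to padded_len, then take this
    # rank's contiguous slice along partition_dim.
    pad_len = padded_len - v.shape[padding_dim]
    if pad_len > 0:
        pad = torch.zeros(*v.shape[:padding_dim], pad_len, *v.shape[padding_dim + 1:],
                          dtype=v.dtype, device=v.device)
        v = torch.cat((v, pad), dim=padding_dim)
    slice_size = v.shape[partition_dim] // tp_size
    return v.narrow(partition_dim, tp_rank * slice_size, slice_size).contiguous()

full_state = torch.randn(9, 8)  # e.g. a vocab dim that gets padded from 9 to 10
shard = pad_then_shard(full_state, 10, padding_dim=0, partition_dim=0, tp_rank=1, tp_size=2)
assert shard.shape == (5, 8)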