@@ -90,50 +90,66 @@ def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: L
9090 self .num_model_layers = num_model_layers
9191 self .num_layers_per_stage = num_layers_per_stage
9292
def distribute_layers(
    self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
    """Divide ``num_layers`` layers across pipeline stages (and model chunks).

    Args:
        num_layers (int): total number of layers to distribute.
        num_stages (Optional[int]): number of stages. Defaults to ``self.num_stages``.
        num_model_chunks (Optional[int]): number of model chunks per stage. Defaults to
            ``self.num_model_chunks`` when ``self.is_interleave`` is truthy, otherwise 1.

    Returns:
        List[int]: the layer count of every slot; the list has
        ``num_stages * num_model_chunks`` entries. When
        ``self.control_distribute_layers`` is set, a copy of the configured
        ``self.num_layers_per_stage`` is returned instead.
    """
    num_stages = self.num_stages if num_stages is None else num_stages
    if num_model_chunks is None:
        num_model_chunks = self.num_model_chunks if self.is_interleave else 1

    if self.control_distribute_layers:
        # The distribution was fixed by the user; only check that it targets this model.
        assert num_layers == self.num_model_layers, (
            f"expected {self.num_model_layers} layers from the distribution config, got {num_layers}"
        )
        # Return a copy so callers cannot mutate the stored configuration.
        return list(self.num_layers_per_stage)

    num_slots = num_stages * num_model_chunks
    quotient, remainder = divmod(num_layers, num_slots)

    # Every slot gets the same base number of layers ...
    layers_per_stage = [quotient] * num_slots

    # ... and the leftover layers go to a contiguous run of slots around the middle.
    if remainder > 0:
        start_position = num_slots // 2 - remainder // 2
        for slot in range(start_position, start_position + remainder):
            layers_per_stage[slot] += 1
    return layers_per_stage
113119
def get_stage_index(
    self,
    layers_per_stage: List[int],
    stage: Optional[int] = None,
    num_model_chunks: Optional[int] = None,
    num_stages: Optional[int] = None,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
    """
    Get the start index and end index of the layers held by a stage.

    Args:
        layers_per_stage (List[int]): number of layers for each stage (one entry per
            stage/model-chunk slot)
        stage (int): the stage index. Defaults to ``self.stage``.
        num_model_chunks (int): number of model chunks. Defaults to ``self.num_model_chunks``
            when ``self.is_interleave`` is truthy, otherwise 1.
        num_stages (int): number of stages. Defaults to ``self.num_stages``.

    Returns:
        - a ``[start, end]`` pair: the start index and end index of this stage
          (when ``num_model_chunks == 1``)
        - a list of ``[start, end]`` pairs: the start index and end index of this stage
          for each model chunk (otherwise)

    Raises:
        IndexError: if a requested slot does not exist in ``layers_per_stage``.
    """
    stage = self.stage if stage is None else stage
    if num_model_chunks is None:
        num_model_chunks = self.num_model_chunks if self.is_interleave else 1
    num_stages = self.num_stages if num_stages is None else num_stages

    # Prefix sums with a leading 0: entry i is the first layer of slot i, entry i + 1 its end.
    num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
    num_slots = len(layers_per_stage)

    stage_indices = []
    for model_chunk in range(num_model_chunks):
        # Chunk ``model_chunk`` of this stage lives ``num_stages`` slots after the previous chunk.
        slot = stage + model_chunk * num_stages
        if not 0 <= slot < num_slots:
            # Reject explicitly: a negative slot would otherwise wrap around silently.
            raise IndexError(
                f"slot {slot} (stage={stage}, model_chunk={model_chunk}, num_stages={num_stages}) "
                f"is out of range for {num_slots} entries in layers_per_stage"
            )
        start_idx = int(num_layers_per_stage_accumulated[slot])
        end_idx = int(num_layers_per_stage_accumulated[slot + 1])
        stage_indices.append([start_idx, end_idx])

    return stage_indices[0] if num_model_chunks == 1 else stage_indices
0 commit comments