
Commit 6822ee4

fix: add optional args for `distribute_layers` and `get_stage_index`
1 parent: c27225a

3 files changed (+49 −21 lines)

colossalai/pipeline/stage_manager.py

Lines changed: 25 additions & 9 deletions
@@ -90,50 +90,66 @@ def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]):
         self.num_model_layers = num_model_layers
         self.num_layers_per_stage = num_layers_per_stage

-    def distribute_layers(self, num_layers: int) -> List[int]:
+    def distribute_layers(
+        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
+    ) -> List[int]:
         """Divide layers into stages"""
+        num_stages = self.num_stages if num_stages is None else num_stages
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+
         if self.control_distribute_layers:
             assert num_layers == self.num_model_layers
             return self.num_layers_per_stage

         else:
-            num_model_chunk = self.num_model_chunks if self.is_interleave else 1
-            quotient = num_layers // (self.num_stages * num_model_chunk)
-            remainder = num_layers % (self.num_stages * num_model_chunk)
+            quotient = num_layers // (num_stages * num_model_chunks)
+            remainder = num_layers % (num_stages * num_model_chunks)

             # calculate the num_layers per stage
-            layers_per_stage = [quotient] * self.num_stages * num_model_chunk
+            layers_per_stage = [quotient] * num_stages * num_model_chunks

             # deal with the rest layers
             if remainder > 0:
-                start_position = (self.num_stages * num_model_chunk) // 2 - remainder // 2
+                start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
                 for i in range(start_position, start_position + remainder):
                     layers_per_stage[i] += 1
             return layers_per_stage

     def get_stage_index(
         self,
         layers_per_stage: List[int],
+        stage: Optional[int] = None,
+        num_model_chunks: Optional[int] = None,
+        num_stages: Optional[int] = None,
     ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
         """
         Get the start index and end index of layers for each stage.

         Args:
             layers_per_stage (List[int]): number of layers for each stage
             stage (int): the stage index
+            num_stages (int): number of stages
+            num_model_chunks (int): number of model chunks

         Returns:
             - Tuple[int, int]: the start index and end index of this stage
             - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk

         """
+        stage = self.stage if stage is None else stage
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+        num_stages = self.num_stages if num_stages is None else num_stages
+
         num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

         stage_indices = []
-        num_model_chunks = self.num_model_chunks if self.is_interleave else 1
         for model_chunk in range(num_model_chunks):
-            start_idx = num_layers_per_stage_accumulated[self.stage + model_chunk * self.num_stages]
-            end_idx = num_layers_per_stage_accumulated[self.stage + model_chunk * self.num_stages + 1]
+            start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
+            end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
             stage_indices.append([start_idx, end_idx])

         return stage_indices[0] if num_model_chunks == 1 else stage_indices
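
For illustration, a minimal standalone sketch of the arithmetic these two methods now expose through optional arguments. The function names and example values here are hypothetical, and the defaults that the real methods read from the stage manager's attributes are passed in explicitly:

```python
from typing import List, Tuple, Union

def distribute_layers(num_layers: int, num_stages: int, num_model_chunks: int = 1) -> List[int]:
    # Each (stage, model chunk) slot gets the quotient; leftover layers
    # are assigned to the slots in the middle of the pipeline.
    quotient, remainder = divmod(num_layers, num_stages * num_model_chunks)
    layers_per_stage = [quotient] * (num_stages * num_model_chunks)
    start = (num_stages * num_model_chunks) // 2 - remainder // 2
    for i in range(start, start + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(
    layers_per_stage: List[int], stage: int, num_stages: int, num_model_chunks: int = 1
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
    # Prefix sums turn per-slot layer counts into global [start, end) indices.
    acc = [0]
    for n in layers_per_stage:
        acc.append(acc[-1] + n)
    indices = [
        (acc[stage + chunk * num_stages], acc[stage + chunk * num_stages + 1])
        for chunk in range(num_model_chunks)
    ]
    return indices[0] if num_model_chunks == 1 else indices

# 10 layers over 4 stages: the 2 leftover layers land on the middle stages.
print(distribute_layers(10, 4))              # [2, 3, 3, 2]
print(get_stage_index([2, 3, 3, 2], 1, 4))   # (2, 5): stage 1 holds layers 2..4
```

Making num_stages and num_model_chunks overridable is what lets the T5 and Whisper policies below reuse the stage manager for their encoder and decoder halves separately instead of duplicating this logic.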

colossalai/shardformer/policies/t5.py

Lines changed: 13 additions & 6 deletions
@@ -251,6 +251,8 @@ def distribute_t5_layers(
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."

         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -262,7 +264,7 @@ def distribute_t5_layers(

         # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages

         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -273,21 +275,26 @@ def objective(num_encoder_stages):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages

-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages

     def get_t5_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."
+
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return stage_manager.get_stage_index(
+                layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage
+            )

     def get_held_layers(self) -> List[nn.Module]:
         """Get pipeline layers for current stage."""

colossalai/shardformer/policies/whisper.py

Lines changed: 11 additions & 6 deletions
@@ -300,6 +300,8 @@ def distribute_whisper_layers(
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"

         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -311,7 +313,7 @@ def distribute_whisper_layers(

         # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages

         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -322,21 +324,24 @@ def objective(num_encoder_stages):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages

-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages

     def get_whisper_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"
+
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(
+            return stage_manager.get_stage_index(
                 layers_per_stage[decoder_starting_stage:],
                 stage - decoder_starting_stage,
             )
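
The Whisper helper mirrors the T5 one; the piece both share is the re-basing of the stage index before delegating to stage_manager.get_stage_index. A simplified sketch with a hypothetical helper name, ignoring interleaved model chunks:

```python
from typing import List, Tuple

def stage_index_in_half(
    layers_per_stage: List[int], stage: int, decoder_starting_stage: int
) -> Tuple[int, int]:
    # Encoder stages index into the encoder slice of the distribution;
    # decoder stages are re-based so that decoder_starting_stage maps to 0.
    if stage < decoder_starting_stage:
        half, local_stage = layers_per_stage[:decoder_starting_stage], stage
    else:
        half, local_stage = layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage
    start = sum(half[:local_stage])
    return start, start + half[local_stage]

# Distribution [4, 4, 2, 2] with the decoder starting at stage 2:
print(stage_index_in_half([4, 4, 2, 2], 3, 2))  # (2, 4): last decoder stage holds decoder layers 2..3
```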
