Commit be6625c (parent: dc95e1e)

Update hybrid_parallel_plugin.py

fix fix

File tree: 8 files changed, +90 −47 lines

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -931,7 +931,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
-        make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 128.
+        make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
     """

     def __init__(
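For context, the padded vocabulary size is normally derived from make_vocab_size_divisible_by together with the tensor parallel degree. The sketch below only mirrors the rounding rule visible in the commented-out preprocessing that this commit deletes from the T5 policy further down; it is an illustration, not ColossalAI's exact implementation, and padded_vocab_size is a hypothetical helper name.

def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int, tp_size: int = 1) -> int:
    # Round vocab_size up to the next multiple of (make_vocab_size_divisible_by * tp_size).
    multiple = make_vocab_size_divisible_by * tp_size
    if vocab_size % multiple != 0:
        vocab_size += multiple - vocab_size % multiple
    return vocab_size

print(padded_vocab_size(32000, 64, tp_size=2))  # 32000, already a multiple of 128
print(padded_vocab_size(32001, 64, tp_size=2))  # 32128, rounded up to the next multiple of 128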

colossalai/shardformer/layer/embedding.py

Lines changed: 4 additions & 2 deletions
@@ -192,7 +192,8 @@ def __init__(

         super().__init__(self.num_embeddings, num_embeddings, weight)

-        self.resize_embedding_weight()
+        if weight.shape[0] < self.num_embeddings:
+            self.resize_embedding_weight()

         if weight is None:
             self.reset_parameters()
@@ -306,7 +307,8 @@ def __init__(

         # resize vocabulary size
         super().__init__(self.num_embeddings, num_embeddings, weight)
-        self.resize_embedding_weight()
+        if not is_distributed_tensor(self.weight):
+            self.resize_embedding_weight()

         # deal with tensor parallelism
         self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
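The two new guards skip the resize when the weight already covers the padded vocabulary, or when it is already a sharded distributed tensor. Below is a minimal sketch of the padding step itself, assuming it simply zero-pads the row (vocab) dimension; pad_embedding_rows is a hypothetical stand-in for resize_embedding_weight(), not the library's implementation.

import torch
import torch.nn.functional as F

def pad_embedding_rows(weight: torch.Tensor, padded_vocab_size: int) -> torch.Tensor:
    # Zero-pad the row (vocab) dimension up to padded_vocab_size; no-op if already large enough.
    num_padding_rows = padded_vocab_size - weight.shape[0]
    if num_padding_rows <= 0:
        return weight
    return F.pad(weight, (0, 0, 0, num_padding_rows))

weight = torch.randn(32000, 128)            # 32000-token vocab, hidden size 128
padded = pad_embedding_rows(weight, 32064)  # pad up to the next multiple of 64
print(padded.shape)                         # torch.Size([32064, 128])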

colossalai/shardformer/policies/llama.py

Lines changed: 0 additions & 1 deletion
@@ -297,7 +297,6 @@ def module_policy(self):
                 ],
             )
         }
-        print("new_item", new_item)
         policy.update(new_item)

         if self.pipeline_stage_manager:

colossalai/shardformer/policies/opt.py

Lines changed: 0 additions & 1 deletion
@@ -65,7 +65,6 @@ def module_policy(self):
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
         else:
-            # TODO when not tie weight and not pad the vocab size
             if self.tie_weight:
                 embedding_cls = PaddingEmbedding
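The surviving branch is the same selection rule this commit also adds to the T5 policy below: tensor parallelism takes priority and uses VocabParallelEmbedding1D, tied weights without tensor parallelism fall back to PaddingEmbedding, and an untied model without tensor parallelism keeps its original embedding. A condensed restatement as a plain function (the strings stand in for the real layer classes; this is not the policy API):

def choose_embedding_cls(enable_tensor_parallelism: bool, tie_weight: bool):
    # Mirrors the branch in the OPT/T5 policies: TP first, then tied weights, else no replacement.
    if enable_tensor_parallelism:
        return "VocabParallelEmbedding1D"
    if tie_weight:
        return "PaddingEmbedding"
    return None

print(choose_embedding_cls(True, False))   # VocabParallelEmbedding1D
print(choose_embedding_cls(False, True))   # PaddingEmbedding
print(choose_embedding_cls(False, False))  # None, the embedding is left as-is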

colossalai/shardformer/policies/t5.py

Lines changed: 82 additions & 36 deletions
@@ -11,6 +11,8 @@
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    PaddingEmbedding,
+    PaddingLMHead,
     RMSNorm,
     VocabParallelEmbedding1D,
     VocabParallelLMHead1D,
@@ -35,22 +37,18 @@ def config_sanity_check(self):
         pass

     def preprocess(self):
-        # reshape the embedding layer
-        r"""
-        Reshape the Embedding layer to make the embedding dimension divisible by world_size
-        """
-        # TODO padding the vocab size in VocabParallelEmbedding1D
-        # vocab_size = self.model.config.vocab_size
-        # if self.shard_config.enable_tensor_parallelism:
-        #     world_size = self.shard_config.tensor_parallel_size
-        #     multiple = world_size * self.shard_config.make_vocab_size_divisible_by
-        # else:
-        #     multiple = self.shard_config.make_vocab_size_divisible_by
-        # if vocab_size % multiple != 0:
-        #     new_vocab_size = vocab_size + multiple - vocab_size % multiple
-        #     self.model.resize_token_embeddings(new_vocab_size)
+        self.tie_weight = self.tie_weight_check()
         return self.model

+    def tie_weight_check(self):
+        input_embedding = self.model.get_input_embeddings()
+        output_embedding = self.model.get_output_embeddings()
+        return (
+            input_embedding is not None
+            and output_embedding is not None
+            and id(input_embedding.weight) == id(output_embedding.weight)
+        )
+
     def module_policy(self):
         from transformers.models.t5.modeling_t5 import (
             T5Attention,
@@ -64,6 +62,13 @@ def module_policy(self):

         policy = {}

+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
         if self.shard_config.enable_fused_normalization:
             norm_cls = FusedRMSNorm
         else:
@@ -80,11 +85,6 @@ def module_policy(self):
                     suffix="dropout",
                     target_module=DropoutForParallelInput,
                 ),
-                SubModuleReplacementDescription(
-                    suffix="embed_tokens",
-                    target_module=VocabParallelEmbedding1D,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
-                ),
             ]
         )
         policy[T5LayerSelfAttention] = ModulePolicyDescription(
@@ -180,6 +180,17 @@ def module_policy(self):
             ]
         )

+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=T5Stack,
+            )
+
         # optimization configuration
         self.append_or_create_submodule_replacement(
             description=SubModuleReplacementDescription(
@@ -371,11 +382,18 @@ def module_policy(self):

         policy = super().module_policy()

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="shared",
-                    target_module=VocabParallelEmbedding1D,
+                    target_module=embedding_cls,
                     kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
@@ -408,23 +426,44 @@ def module_policy(self):

         policy = super().module_policy()

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="shared",
-                        target_module=VocabParallelEmbedding1D,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="lm_head",
-                        target_module=VocabParallelLMHead1D,
-                        kwargs=dict(
-                            gather_output=True,
-                            make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
-                        ),
-                    ),
-                ],
+                description=SubModuleReplacementDescription(
+                    suffix="shared",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=T5ForConditionalGeneration,
+            )
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": True,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    },
+                ),
+                policy=policy,
+                target_key=T5ForConditionalGeneration,
+            )
+        else:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=PaddingLMHead,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
                 policy=policy,
                 target_key=T5ForConditionalGeneration,
             )
@@ -475,11 +514,18 @@ def module_policy(self):

         policy = super().module_policy()

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="shared",
-                    target_module=VocabParallelEmbedding1D,
+                    target_module=embedding_cls,
                     kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
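The new tie_weight_check() relies on the fact that Hugging Face models expose weight tying as the input and output embeddings sharing one Parameter object. Here is a small sketch of that check on a stock transformers T5 model, built from a default config so nothing needs to be downloaded; since the default config ties the word embeddings, this is expected to print True.

from transformers import T5Config, T5ForConditionalGeneration

# Default T5Config keeps tie_word_embeddings=True, so lm_head shares the shared embedding weight.
model = T5ForConditionalGeneration(T5Config())
inp = model.get_input_embeddings()
out = model.get_output_embeddings()
tied = inp is not None and out is not None and id(inp.weight) == id(out.weight)
print(tied)  # expected: True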

colossalai/shardformer/shard/sharder.py

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ def shard(self) -> List[Dict[int, Tensor]]:
         self._preprocess()
         # get shared params before release unheld layers, this avoid misjudgment of shared params (None is None)
         shared_params = self.policy.get_shared_params()
-        print("shared_params", shared_params)
         held_layers = self._release_unheld_layers()
         self._replace_module(include=held_layers)
         self._materialize()

tests/test_shardformer/test_model/test_shard_opt.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     # optimizer executes step
     org_optimizer.step()
-    # sharded_optimizer.step()
+    sharded_optimizer.step()

     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():

tests/test_shardformer/test_model/test_shard_t5.py

Lines changed: 2 additions & 4 deletions
@@ -203,16 +203,14 @@ def check_t5_3d(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_t5_3d_test()

-# TODO padding the vocab size in VocabParallelEmbedding1D
-@pytest.mark.skip("padding the vocab size in VocabParallelEmbedding1D")
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_t5():
     spawn(check_t5, 4)

-# TODO padding the vocab size in VocabParallelEmbedding1D
-@pytest.mark.skip("padding the vocab size in VocabParallelEmbedding1D")
+
 @pytest.mark.largedist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
