     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    PaddingEmbedding,
+    PaddingLMHead,
     RMSNorm,
     VocabParallelEmbedding1D,
     VocabParallelLMHead1D,
@@ -35,22 +37,18 @@ def config_sanity_check(self):
         pass
 
     def preprocess(self):
-        # reshape the embedding layer
-        r"""
-        Reshape the Embedding layer to make the embedding dimension divisible by world_size
-        """
-        # TODO padding the vocab size in VocabParallelEmbedding1D
-        # vocab_size = self.model.config.vocab_size
-        # if self.shard_config.enable_tensor_parallelism:
-        #     world_size = self.shard_config.tensor_parallel_size
-        #     multiple = world_size * self.shard_config.make_vocab_size_divisible_by
-        # else:
-        #     multiple = self.shard_config.make_vocab_size_divisible_by
-        # if vocab_size % multiple != 0:
-        #     new_vocab_size = vocab_size + multiple - vocab_size % multiple
-        #     self.model.resize_token_embeddings(new_vocab_size)
+        self.tie_weight = self.tie_weight_check()
         return self.model
 
+    def tie_weight_check(self):
+        input_embedding = self.model.get_input_embeddings()
+        output_embedding = self.model.get_output_embeddings()
+        return (
+            input_embedding is not None
+            and output_embedding is not None
+            and id(input_embedding.weight) == id(output_embedding.weight)
+        )
+
     def module_policy(self):
         from transformers.models.t5.modeling_t5 import (
             T5Attention,
@@ -64,6 +62,13 @@ def module_policy(self):
 
         policy = {}
 
+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
         if self.shard_config.enable_fused_normalization:
             norm_cls = FusedRMSNorm
         else:
@@ -80,11 +85,6 @@ def module_policy(self):
                         suffix="dropout",
                         target_module=DropoutForParallelInput,
                     ),
-                    SubModuleReplacementDescription(
-                        suffix="embed_tokens",
-                        target_module=VocabParallelEmbedding1D,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
-                    ),
                 ]
             )
             policy[T5LayerSelfAttention] = ModulePolicyDescription(
@@ -180,6 +180,17 @@ def module_policy(self):
                 ]
             )
 
+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=T5Stack,
+            )
+
         # optimization configuration
         self.append_or_create_submodule_replacement(
             description=SubModuleReplacementDescription(
@@ -371,11 +382,18 @@ def module_policy(self):
 
         policy = super().module_policy()
 
+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="shared",
-                    target_module=VocabParallelEmbedding1D,
+                    target_module=embedding_cls,
                     kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
@@ -408,23 +426,44 @@ def module_policy(self):
 
         policy = super().module_policy()
 
+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="shared",
-                        target_module=VocabParallelEmbedding1D,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="lm_head",
-                        target_module=VocabParallelLMHead1D,
-                        kwargs=dict(
-                            gather_output=True,
-                            make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
-                        ),
-                    ),
-                ],
+                description=SubModuleReplacementDescription(
+                    suffix="shared",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=T5ForConditionalGeneration,
+            )
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": True,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    },
+                ),
+                policy=policy,
+                target_key=T5ForConditionalGeneration,
+            )
+        else:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=PaddingLMHead,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
                 policy=policy,
                 target_key=T5ForConditionalGeneration,
             )
@@ -475,11 +514,18 @@ def module_policy(self):
 
         policy = super().module_policy()
 
+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="shared",
-                    target_module=VocabParallelEmbedding1D,
+                    target_module=embedding_cls,
                     kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
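
For orientation, the module selection this diff repeats in each module_policy reduces to the branching sketched below. Here tie_weight is the result of the new tie_weight_check(), i.e. whether the input and output embeddings share one weight tensor. This is only an illustrative summary: the helper names pick_embedding_cls and pick_lm_head_cls are hypothetical and do not exist in the repository, where the logic lives inline in the T5 policy classes and assigns the real layer classes rather than strings.

from typing import Optional

# Hypothetical helper mirroring the branching added for the `shared` / `embed_tokens` embeddings.
def pick_embedding_cls(enable_tensor_parallelism: bool, tie_weight: bool) -> Optional[str]:
    if enable_tensor_parallelism:
        # Vocabulary dimension is sharded across tensor-parallel ranks.
        return "VocabParallelEmbedding1D"
    if tie_weight:
        # No tensor parallelism, but tied input/output embeddings are padded so the
        # vocabulary size stays divisible by make_vocab_size_divisible_by.
        return "PaddingEmbedding"
    # Otherwise the original embedding module is left untouched.
    return None

# Hypothetical helper mirroring the lm_head branching in T5ForConditionalGenerationPolicy.
def pick_lm_head_cls(enable_tensor_parallelism: bool) -> str:
    if enable_tensor_parallelism:
        # Replaced with gather_output=True so logits are gathered back to the full vocabulary size.
        return "VocabParallelLMHead1D"
    return "PaddingLMHead"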