@@ -82,10 +82,9 @@ def __init__(
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-        *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__(weight=weight, bias_=bias_, **kwargs)
 
         # Keep input parameters
         self.in_features = in_features
@@ -141,7 +140,7 @@ def __init__(
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
     ) -> ParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -174,7 +173,6 @@ def from_native_module(
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            *args,
             **kwargs,
         )
 
@@ -316,7 +314,7 @@ def __init__(
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
     ) -> ParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -350,7 +348,6 @@ def from_native_module(
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            *args,
             **kwargs,
         )
 
@@ -477,7 +474,7 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
     ) -> PaddingParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -489,7 +486,6 @@ def from_native_module(
         bias = module.bias is not None
         device = module.weight.device
         # ensure only one process group is passed
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
 
         lm_head_linear = PaddingLMHead(
             in_features=in_features,
@@ -498,8 +494,6 @@ def from_native_module(
             device=device,
             weight=module.weight,
             bias_=module.bias,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
-            *args,
             **kwargs,
         )
 
@@ -551,7 +545,6 @@ def __init__(
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         make_vocab_size_divisible_by: int = 64,
-        *args,
         **kwargs,
     ):
         # create weight and bias
@@ -579,12 +572,9 @@ def __init__(
             process_group=process_group,
             weight=weight,
             bias_=bias_,
-            *args,
             **kwargs,
             new_num_embeddings=new_out_features,
             old_num_embeddings=out_features,
-            weight_A=weight,
-            bias_A=bias_,
         )
 
         # get the length of valid embeddings
@@ -599,7 +589,7 @@ def __init__(
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
     ) -> PaddingParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -611,8 +601,6 @@ def from_native_module(
         bias = module.bias is not None
         device = module.weight.device
 
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
-
         lm_head_linear = VocabParallelLMHead1D(
             in_features=in_features,
             out_features=out_features,
@@ -621,41 +609,18 @@ def from_native_module(
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
-            *args,
             **kwargs,
         )
 
         return lm_head_linear
 
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
-        assert (
-            input_.shape[-1] == self.weight.shape[-1]
-        ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
-            input_.shape, self.weight.shape, self.weight.shape[-1]
-        )
-
-        # Set up backprop all-reduce.
-        input_parallel = input_
-
-        # Matrix multiply.
-        bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel:
-            output_parallel = linear_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
-            )
+        if self.skip_bias_add:
+            output, _ = super().forward(input_)
         else:
-            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
-
+            output = super().forward(input_)
         if self.gather_output:
-            # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
             output = output[..., : self.old_num_embeddings]
         else:
-            output = output_parallel
             output = output[..., : self.num_valid_embeddings_local]
-
-        if self.skip_bias_add:
-            return output, self.bias
-        else:
-            return output
+        return output