@@ -82,8 +82,10 @@ def __init__(
8282 bias_ : Optional [Parameter ] = None ,
8383 weight_initializer : Callable = init .kaiming_uniform_ (a = math .sqrt (5 )),
8484 bias_initializer : Callable = init .xavier_uniform_ (a = 1 , scale = 1 ),
85+ * args ,
86+ ** kwargs ,
8587 ):
86- super ().__init__ ()
88+ super ().__init__ (* args , ** kwargs )
8789
8890 # Keep input parameters
8991 self .in_features = in_features
@@ -509,7 +511,7 @@ def forward(self, input: Tensor) -> Tensor:
509511 return output
510512
511513
512- class VocabParallelLMHead1D (PaddingParallelModule , Linear1D_Col ):
514+ class VocabParallelLMHead1D (Linear1D_Col , PaddingParallelModule ):
513515 r"""Linear layer with column parallelism.
514516
515517 The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
@@ -570,10 +572,6 @@ def __init__(
570572 new_out_features = out_features + multiple - (out_features % multiple )
571573
572574 super ().__init__ (
573- new_num_embeddings = new_out_features ,
574- old_num_embeddings = out_features ,
575- weight_A = weight ,
576- bias_A = bias_ ,
577575 in_features = in_features ,
578576 out_features = new_out_features ,
579577 bias = bias ,
@@ -583,7 +581,12 @@ def __init__(
583581 bias_ = bias_ ,
584582 * args ,
585583 ** kwargs ,
584+ new_num_embeddings = new_out_features ,
585+ old_num_embeddings = out_features ,
586+ weight_A = weight ,
587+ bias_A = bias_ ,
586588 )
589+
587590 # get the length of valid embeddings
588591 tp_rank = dist .get_rank (process_group )
589592 partition_size = self .new_num_embeddings // dist .get_world_size (process_group )
0 commit comments