@@ -67,8 +67,8 @@ class Linear1D_Col(ParallelModule):
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
+        in_features: int = None,
+        out_features: int = None,
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
@@ -82,8 +82,10 @@ def __init__(
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        *args,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(*args, **kwargs)
 
         # Keep input parameters
         self.in_features = in_features
@@ -509,7 +511,7 @@ def forward(self, input: Tensor) -> Tensor:
         return output
 
 
-class VocabParallelLMHead1D(PaddingParallelModule, Linear1D_Col):
+class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
     r"""Linear layer with column parallelism.
 
     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
@@ -540,8 +542,8 @@ class VocabParallelLMHead1D(PaddingParallelModule, Linear1D_Col):
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
+        in_features: int = None,
+        out_features: int = None,
         bias: bool = True,
         dtype: torch.dtype = None,
         device: torch.device = None,
@@ -570,10 +572,6 @@ def __init__(
             new_out_features = out_features + multiple - (out_features % multiple)
 
         super().__init__(
-            new_num_embeddings=new_out_features,
-            old_num_embeddings=out_features,
-            weight_A=weight,
-            bias_A=bias_,
             in_features=in_features,
             out_features=new_out_features,
             bias=bias,
@@ -583,7 +581,12 @@ def __init__(
             bias_=bias_,
             *args,
             **kwargs,
+            new_num_embeddings=new_out_features,
+            old_num_embeddings=out_features,
+            weight_A=weight,
+            bias_A=bias_,
         )
+
         # get the length of valid embeddings
         tp_rank = dist.get_rank(process_group)
         partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
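
The crux of this commit is the base-order swap: with `class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule)`, Python's C3 linearization places `Linear1D_Col` ahead of `PaddingParallelModule` in the MRO, so `Linear1D_Col.__init__` runs first and must forward unconsumed arguments rather than end the chain with a bare `super().__init__()`. That is why `Linear1D_Col` gains `*args, **kwargs`, and why the padding keywords (`new_num_embeddings`, `old_num_embeddings`, `weight_A`, `bias_A`) move after `**kwargs` in the call: they now travel down the MRO until `PaddingParallelModule` consumes them. The `= None` defaults on `in_features`/`out_features` presumably keep the signature tolerant of keyword-only construction paths. Below is a minimal sketch of this cooperative-__init__ pattern; the class bodies are simplified stand-ins, not the real ColossalAI implementations.

# Illustrative sketch only: each __init__ consumes its own keywords and
# forwards the rest along the MRO.
class ParallelModule:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # keep the MRO chain going

class PaddingParallelModule(ParallelModule):
    def __init__(self, new_num_embeddings=0, old_num_embeddings=0, **kwargs):
        super().__init__(**kwargs)
        self.new_num_embeddings = new_num_embeddings
        self.old_num_embeddings = old_num_embeddings

class Linear1D_Col(ParallelModule):
    def __init__(self, in_features=None, out_features=None, **kwargs):
        super().__init__(**kwargs)  # forward padding kwargs down the MRO
        self.in_features = in_features
        self.out_features = out_features

class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
    def __init__(self, in_features=None, out_features=None, **kwargs):
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            **kwargs,
            # keywords after **kwargs, as in the diff above; consumed
            # further down the MRO by PaddingParallelModule
            new_num_embeddings=out_features,
            old_num_embeddings=out_features,
        )

head = VocabParallelLMHead1D(in_features=16, out_features=32)
print([c.__name__ for c in type(head).__mro__])
# ['VocabParallelLMHead1D', 'Linear1D_Col', 'PaddingParallelModule',
#  'ParallelModule', 'object']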
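The padding arithmetic kept in the `@@ -570,10 +572,6 @@` hunk rounds the vocabulary size up so it divides evenly across tensor-parallel ranks, matching the `partition_size = self.new_num_embeddings // dist.get_world_size(process_group)` line at the end of the diff. A hedged worked example follows; the guard for already-aligned sizes is an assumption, since the diff omits the enclosing branch:

# Sketch of the round-up-to-multiple computation (enclosing branch assumed).
def pad_vocab(out_features: int, multiple: int) -> int:
    if out_features % multiple == 0:
        return out_features  # assumed: already aligned, no padding needed
    return out_features + multiple - (out_features % multiple)

padded = pad_vocab(50257, 128)  # e.g. GPT-2's vocab: 50257 % 128 == 81
assert padded == 50304 and padded % 128 == 0  # 50257 + 128 - 81
# With 8 tensor-parallel ranks, each then holds 50304 // 8 == 6288 rows.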