
Commit 8f2db47 (parent: fc52763)

Commit message: fix

File tree: 3 files changed, +12 -15 lines

colossalai/shardformer/layer/linear.py

Lines changed: 9 additions & 6 deletions
@@ -82,8 +82,10 @@ def __init__(
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        *args,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(*args, **kwargs)

         # Keep input parameters
         self.in_features = in_features
@@ -509,7 +511,7 @@ def forward(self, input: Tensor) -> Tensor:
         return output


-class VocabParallelLMHead1D(PaddingParallelModule, Linear1D_Col):
+class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
     r"""Linear layer with column parallelism.

     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
@@ -570,10 +572,6 @@ def __init__(
             new_out_features = out_features + multiple - (out_features % multiple)

         super().__init__(
-            new_num_embeddings=new_out_features,
-            old_num_embeddings=out_features,
-            weight_A=weight,
-            bias_A=bias_,
             in_features=in_features,
             out_features=new_out_features,
             bias=bias,
@@ -583,7 +581,12 @@ def __init__(
             bias_=bias_,
             *args,
             **kwargs,
+            new_num_embeddings=new_out_features,
+            old_num_embeddings=out_features,
+            weight_A=weight,
+            bias_A=bias_,
         )
+
         # get the length of valid embeddings
         tp_rank = dist.get_rank(process_group)
         partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
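Both changes in this file serve one cooperative __init__ chain: with Linear1D_Col listed first among the bases of VocabParallelLMHead1D, its __init__ runs first, keeps the linear-layer arguments, and forwards everything else through super().__init__(*args, **kwargs); that is also why the padding-specific keyword arguments can simply trail the others in a single call. A minimal sketch of the pattern, with simplified stand-in signatures rather than the real ColossalAI APIs:

# Sketch of the cooperative __init__ pattern above; class names mirror the
# diff, but the signatures are simplified stand-ins, not the real ones.
class PaddingParallelModule:
    def __init__(self, new_num_embeddings=0, old_num_embeddings=0, **kwargs):
        super().__init__(**kwargs)  # continue the MRO chain (ends at object)
        self.new_num_embeddings = new_num_embeddings
        self.old_num_embeddings = old_num_embeddings


class Linear1D_Col:
    def __init__(self, in_features=0, out_features=0, **kwargs):
        super().__init__(**kwargs)  # padding kwargs ride through to the next base
        self.in_features = in_features
        self.out_features = out_features


class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
    def __init__(self, in_features, out_features, **kwargs):
        # One call walks the whole MRO; each base keeps the kwargs it knows.
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            new_num_embeddings=out_features + 8,  # illustrative padding amount
            old_num_embeddings=out_features,
            **kwargs,
        )


head = VocabParallelLMHead1D(in_features=16, out_features=50257)
print(VocabParallelLMHead1D.__mro__)              # Linear1D_Col before PaddingParallelModule
print(head.in_features, head.new_num_embeddings)  # 16 50265

With the previous base order, PaddingParallelModule.__init__ would have run first, which appears to be why its arguments were passed explicitly before; the swapped order plus the *args/**kwargs forwarding added to Linear1D_Col lets the whole chain be driven from one super().__init__ call.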

colossalai/shardformer/layer/parallel_module.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@


 class ParallelModule(nn.Module, ABC):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
     @abstractmethod
     def from_native_module(
         module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
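The forwarding __init__ added here follows the general rule behind the linear.py change: every class on a cooperative MRO should pass *args/**kwargs along, otherwise an intermediate class swallows arguments meant for a base further down the chain. A small illustration with invented names, not the ColossalAI classes:

# Illustration only: names are invented and do not reproduce the real classes.
class Terminal:
    def __init__(self):
        print("end of the chain reached")


class PassThrough(Terminal):
    # Same shape as the __init__ added to ParallelModule above.
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)


class Opaque(Terminal):
    def __init__(self):  # accepts nothing and forwards nothing
        super().__init__()


class Consumer(Terminal):
    def __init__(self, pad=0, **kwargs):
        super().__init__(**kwargs)
        self.pad = pad


class Works(PassThrough, Consumer):
    pass


class Broken(Opaque, Consumer):
    pass


Works(pad=4)        # pad travels through PassThrough and is consumed by Consumer
try:
    Broken(pad=4)   # Opaque.__init__ rejects the keyword before Consumer sees it
except TypeError as err:
    print("broken chain:", err)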

colossalai/shardformer/policies/bert.py

Lines changed: 0 additions & 9 deletions
@@ -39,15 +39,6 @@ def preprocess(self):
         self.tie_weight = self.tie_weight_check()
         return self.model

-    def tie_weight_check(self):
-        input_embedding = self.model.get_input_embeddings()
-        output_embedding = self.model.get_output_embeddings()
-        return (
-            input_embedding is not None
-            and output_embedding is not None
-            and id(input_embedding.weight) == id(output_embedding.weight)
-        )
-
     def module_policy(self):
         from transformers.models.bert.modeling_bert import (
             BertEmbeddings,
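The helper deleted here decides whether the input and output embeddings are tied by comparing Parameter identity; preprocess still calls self.tie_weight_check(), so the method presumably now lives on a shared policy base class rather than in this file (the diff alone does not show where). A hedged example of the same identity check against a HuggingFace model; the concrete model class and config are only illustrative, not something this commit prescribes:

# Sketch of the identity check shown in the removed lines.
from transformers import BertConfig, BertForMaskedLM

model = BertForMaskedLM(BertConfig(tie_word_embeddings=True))
input_embedding = model.get_input_embeddings()
output_embedding = model.get_output_embeddings()
tied = (
    input_embedding is not None
    and output_embedding is not None
    and id(input_embedding.weight) == id(output_embedding.weight)
)
print(tied)  # True: the LM head decoder reuses the word-embedding Parameter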
