
Commit ed4a808

resolve super init
1 parent a0ca44e commit ed4a808

File tree: 3 files changed, +19 / -60 lines

  colossalai/shardformer/layer/embedding.py
  colossalai/shardformer/layer/linear.py
  colossalai/shardformer/layer/parallel_module.py

colossalai/shardformer/layer/embedding.py

Lines changed: 0 additions & 5 deletions

@@ -220,7 +220,6 @@ def from_native_module(
         embedding_dim = module.embedding_dim
         padding_idx = module.padding_idx
         device = module.weight.device
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
 
         # create the parallel module
         padding_embedding = PaddingEmbedding(
@@ -229,7 +228,6 @@ def from_native_module(
             padding_idx=padding_idx,
             device=device,
             weight=module.weight,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             *args,
             **kwargs,
         )
@@ -343,8 +341,6 @@ def from_native_module(
         assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
         process_group = process_group[0]
 
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
-
         # create the parallel module
         vocab_embedding_1d = VocabParallelEmbedding1D(
             num_embeddings=num_embeddings,
@@ -353,7 +349,6 @@ def from_native_module(
             device=device,
             process_group=process_group,
             weight=module.weight,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             *args,
             **kwargs,
         )

colossalai/shardformer/layer/linear.py

Lines changed: 9 additions & 44 deletions

@@ -82,10 +82,9 @@ def __init__(
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-        *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__(weight=weight, bias_=bias_, **kwargs)
 
         # Keep input parameters
         self.in_features = in_features
@@ -141,7 +140,7 @@ def __init__(
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
     ) -> ParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -174,7 +173,6 @@ def from_native_module(
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            *args,
             **kwargs,
         )
 
@@ -316,7 +314,7 @@ def __init__(
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -350,7 +348,6 @@ def from_native_module(
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            *args,
             **kwargs,
         )
 
@@ -477,7 +474,7 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
     ) -> PaddingParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -489,7 +486,6 @@ def from_native_module(
         bias = module.bias is not None
         device = module.weight.device
         # ensure only one process group is passed
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
 
         lm_head_linear = PaddingLMHead(
             in_features=in_features,
@@ -498,8 +494,6 @@ def from_native_module(
             device=device,
             weight=module.weight,
             bias_=module.bias,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
-            *args,
             **kwargs,
         )
 
@@ -551,7 +545,6 @@ def __init__(
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         make_vocab_size_divisible_by: int = 64,
-        *args,
         **kwargs,
     ):
         # create weight and bias
@@ -579,12 +572,9 @@ def __init__(
             process_group=process_group,
             weight=weight,
             bias_=bias_,
-            *args,
             **kwargs,
             new_num_embeddings=new_out_features,
             old_num_embeddings=out_features,
-            weight_A=weight,
-            bias_A=bias_,
         )
 
         # get the length of valid embeddings
@@ -599,7 +589,7 @@ def __init__(
 
     @staticmethod
     def from_native_module(
-        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
     ) -> PaddingParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
@@ -611,8 +601,6 @@ def from_native_module(
         bias = module.bias is not None
         device = module.weight.device
 
-        make_vocab_size_divisible_by = kwargs.pop("make_vocab_size_divisible_by", 64)
-
         lm_head_linear = VocabParallelLMHead1D(
             in_features=in_features,
             out_features=out_features,
@@ -621,41 +609,18 @@ def from_native_module(
             process_group=process_group,
             weight=module.weight,
             bias_=module.bias,
-            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
-            *args,
             **kwargs,
         )
 
         return lm_head_linear
 
     def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
-        assert (
-            input_.shape[-1] == self.weight.shape[-1]
-        ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
-            input_.shape, self.weight.shape, self.weight.shape[-1]
-        )
-
-        # Set up backprop all-reduce.
-        input_parallel = input_
-
-        # Matrix multiply.
-        bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel:
-            output_parallel = linear_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
-            )
+        if self.skip_bias_add:
+            output, _ = super().forward(input_)
         else:
-            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
-
+            output = super().forward(input_)
         if self.gather_output:
-            # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
             output = output[..., : self.old_num_embeddings]
         else:
-            output = output_parallel
             output = output[..., : self.num_valid_embeddings_local]
-
-        if self.skip_bias_add:
-            return output, self.bias
-        else:
-            return output
+        return output
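
Note on the rewritten `forward`: `VocabParallelLMHead1D` no longer re-implements the column-parallel matmul (the shape assertion, sequence-parallel branch, and all-gather); it delegates to its parallel-linear parent's `forward` and only slices off the padded vocabulary columns afterwards. A standalone sketch of that delegation pattern with toy classes; the class names, shapes, and the simple local matmul below are illustrative, not the library's implementation:

# Illustrative sketch only, not ColossalAI code.
import torch
import torch.nn as nn


class ToyColumnLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, skip_bias_add: bool = False):
        super().__init__()
        self.skip_bias_add = skip_bias_add
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, input_: torch.Tensor):
        output = input_ @ self.weight.t()
        # the parent defers bias addition when skip_bias_add is set and returns it alongside
        return (output, self.bias) if self.skip_bias_add else output + self.bias


class ToyPaddedLMHead(ToyColumnLinear):
    def __init__(self, in_features: int, old_vocab_size: int, padded_vocab_size: int, **kwargs):
        super().__init__(in_features, padded_vocab_size, **kwargs)
        self.old_num_embeddings = old_vocab_size

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # delegate the matmul to the parent instead of re-implementing it (the new code path)
        if self.skip_bias_add:
            output, _ = super().forward(input_)
        else:
            output = super().forward(input_)
        # drop the padded vocabulary columns before handing the logits back
        return output[..., : self.old_num_embeddings]


head = ToyPaddedLMHead(in_features=8, old_vocab_size=50257, padded_vocab_size=50304)
logits = head(torch.randn(2, 4, 8))
assert logits.shape == (2, 4, 50257)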

colossalai/shardformer/layer/parallel_module.py

Lines changed: 10 additions & 11 deletions

@@ -25,8 +25,8 @@
 
 
 class ParallelModule(nn.Module, ABC):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
+    def __init__(self, **kwargs):
+        super().__init__()
 
     @abstractmethod
     def from_native_module(
@@ -176,21 +176,20 @@ def _load_from_state_dict(
                 unexpected_keys.append(key)
 
 
-class PaddingParallelModule(nn.Module, ABC):
+class PaddingParallelModule(ParallelModule):
     def __init__(
         self,
-        new_num_embeddings: int = None,
-        old_num_embeddings: int = None,
-        weight_A: Optional[nn.Parameter] = None,
-        bias_A: Optional[nn.Parameter] = None,
-        *args,
+        new_num_embeddings: int,
+        old_num_embeddings: int,
+        weight: Optional[nn.Parameter],
+        bias_: Optional[nn.Parameter] = None,
         **kwargs,
     ) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(**kwargs)
         self.new_num_embeddings = new_num_embeddings
         self.old_num_embeddings = old_num_embeddings
-        self.weight = weight_A
-        self.bias = bias_A
+        self.weight = weight
+        self.bias = bias_
 
         if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
             self.resize_embedding_weight()
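
This hunk is the core of the "resolve super init" change: `PaddingParallelModule` now derives from `ParallelModule` instead of `nn.Module` directly, takes `weight`/`bias_` under their usual names instead of `weight_A`/`bias_A`, and each constructor forwards keyword-only arguments up a chain that terminates in a bare `nn.Module.__init__()`. A toy reconstruction of that init chain; the padding logic is stubbed and the names below are illustrative, not the library's actual code:

# Illustrative sketch only, not ColossalAI code.
from abc import ABC
from typing import Optional

import torch
import torch.nn as nn


class ToyParallelModule(nn.Module, ABC):
    def __init__(self, **kwargs):
        # swallow remaining keyword arguments; nn.Module.__init__ takes none
        super().__init__()


class ToyPaddingParallelModule(ToyParallelModule):
    def __init__(
        self,
        new_num_embeddings: int,
        old_num_embeddings: int,
        weight: Optional[nn.Parameter],
        bias_: Optional[nn.Parameter] = None,
        **kwargs,
    ) -> None:
        # nn.Module.__init__ runs via the chain before any parameter is assigned
        super().__init__(**kwargs)
        self.new_num_embeddings = new_num_embeddings
        self.old_num_embeddings = old_num_embeddings
        self.weight = weight
        self.bias = bias_
        if self.weight.shape[0] != self.new_num_embeddings:
            self.resize_embedding_weight()

    def resize_embedding_weight(self) -> None:
        # stub for the real padding logic: pad rows with zeros up to new_num_embeddings
        pad = self.new_num_embeddings - self.weight.shape[0]
        self.weight = nn.Parameter(
            torch.cat([self.weight.data, self.weight.new_zeros(pad, self.weight.shape[1])])
        )


m = ToyPaddingParallelModule(
    new_num_embeddings=64,
    old_num_embeddings=50,
    weight=nn.Parameter(torch.randn(50, 8)),
)
assert m.weight.shape == (64, 8)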
