Skip to content

Commit a0ca44e

Browse files
committed
fix
1 parent 8f2db47 commit a0ca44e

File tree

1 file changed: +4 additions, -4 deletions

colossalai/shardformer/layer/linear.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,11 +591,11 @@ def __init__(
591591
tp_rank = dist.get_rank(process_group)
592592
partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
593593
if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
594-
self.num_valid_embeddings = partition_size
594+
self.num_valid_embeddings_local = partition_size
595595
elif self.old_num_embeddings >= tp_rank * partition_size:
596-
self.num_valid_embeddings = self.old_num_embeddings - tp_rank * partition_size
596+
self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
597597
else:
598-
self.num_valid_embeddings = 0
598+
self.num_valid_embeddings_local = 0
599599

600600
@staticmethod
601601
def from_native_module(
@@ -653,7 +653,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
653653
output = output[..., : self.old_num_embeddings]
654654
else:
655655
output = output_parallel
656-
output = output[..., : self.num_valid_embeddings]
656+
output = output[..., : self.num_valid_embeddings_local]
657657

658658
if self.skip_bias_add:
659659
return output, self.bias

Comments (0)