Merged
Changes from all commits
53 commits
057246d
padding vocab_size when using pipeline parallellism
flybird11111 Mar 7, 2024
4e7a4b4
fix
flybird11111 Mar 7, 2024
5716bc6
fix
flybird11111 Mar 7, 2024
2bc0539
fix gather output
flybird11111 Mar 10, 2024
c4ca32f
fix
flybird11111 Mar 13, 2024
7b98ddb
fix
flybird11111 Mar 14, 2024
2c39843
fix
flybird11111 Mar 17, 2024
4e6eade
fix resize embedding
flybird11111 Mar 17, 2024
70e491b
revert
flybird11111 Mar 18, 2024
f709c3b
revert
flybird11111 Mar 18, 2024
4e6592b
revert
flybird11111 Mar 18, 2024
c17181f
padding vocab
flybird11111 Mar 21, 2024
c321346
padding vocabe
flybird11111 Mar 21, 2024
d4b097d
fix
flybird11111 Mar 22, 2024
e769fe0
fix
flybird11111 Mar 22, 2024
d2c005c
fxi
flybird11111 Mar 22, 2024
0c0c309
test ci
flybird11111 Mar 22, 2024
318309c
fix
flybird11111 Mar 27, 2024
0409f9d
fix
flybird11111 Mar 28, 2024
73fa546
fix
flybird11111 Mar 28, 2024
255b0b3
fix
flybird11111 Mar 29, 2024
de1dd3c
Update hybrid_parallel_plugin.py
flybird11111 Mar 28, 2024
3f4dd6e
fix
flybird11111 Apr 1, 2024
5a39bec
fix
flybird11111 Apr 2, 2024
bd8e88c
fix
flybird11111 Apr 3, 2024
4b85ac0
resolve super init
flybird11111 Apr 3, 2024
ac7aa1c
resolve comments
flybird11111 Apr 8, 2024
169804c
fix
flybird11111 Apr 8, 2024
f318157
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
4fe3eb4
vocab checkpointio
flybird11111 Apr 10, 2024
3aa204e
padding vocab_size when using pipeline parallellism
flybird11111 Mar 7, 2024
9896666
fix
flybird11111 Mar 7, 2024
e934889
fix
flybird11111 Mar 14, 2024
54f1f8c
fix
flybird11111 Mar 17, 2024
cf4bba9
fix resize embedding
flybird11111 Mar 17, 2024
1c24aa3
revert
flybird11111 Mar 18, 2024
d15ebbe
revert
flybird11111 Mar 18, 2024
24f5f2a
padding vocab
flybird11111 Mar 21, 2024
3d6739f
fix
flybird11111 Mar 22, 2024
2149904
fix
flybird11111 Mar 28, 2024
3813616
fix
flybird11111 Apr 1, 2024
0fed3d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
c9c49d1
fix ci
flybird11111 Apr 9, 2024
90c5520
fix
flybird11111 Apr 9, 2024
6c2ba05
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
51b8bcd
fix
flybird11111 Apr 9, 2024
ae964a2
cherry-pick
ver217 Mar 27, 2024
ffd9bc3
revert moe modify
flybird11111 Apr 10, 2024
1114509
Merge branch 'feature/resize-embedding' into padding-vocab
flybird11111 Apr 10, 2024
b570f1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
f08e084
fix
flybird11111 Apr 10, 2024
14a4342
resolve comments
flybird11111 Apr 11, 2024
873e2b3
ptensor
flybird11111 Apr 12, 2024
6 changes: 3 additions & 3 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -44,10 +44,10 @@
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.

if optim is None:
return {}
param_info = {"id2shape": {}}

start_index = 0
for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index):
@@ -527,7 +527,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
optimizer_params_info = get_param_info(optimizer)
params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -558,7 +558,7 @@ def configure(
**self.zero_optim_config,
**self.optim_kwargs,
tp_group=self.tp_group,
optimizer_params_info=optimizer_params_info,
params_info=params_info,
verbose=self.verbose,
)

11 changes: 5 additions & 6 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer):

if optim is None:
return {}
param_info = {
"param_groups": [],
"param2id": {},
"id2param": {},
"param2shape": {},
}
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): The multiple to which the vocabulary size is padded, so that a faster kernel can be chosen. Defaults to 64.

"""

def __init__(
@@ -989,6 +986,7 @@ def __init__(
num_model_chunks: int = 1,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
) -> None:
super().__init__()
assert (
@@ -1095,6 +1093,7 @@ def __init__(
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
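The only user-facing change in the plugin is the new `make_vocab_size_divisible_by` option, which is forwarded to the shard config. Below is a minimal sketch of the rounding it controls; the helper name is illustrative and not part of the plugin API, but the arithmetic matches the `VocabParallelEmbedding1D` change shown further down.

```python
def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int = 64, tp_size: int = 1) -> int:
    # Round the vocabulary up so that the padded size (and hence each
    # tensor-parallel shard) is a multiple of make_vocab_size_divisible_by,
    # letting fused kernels pick faster tile sizes.
    multiple = make_vocab_size_divisible_by * tp_size
    if vocab_size % multiple != 0:
        vocab_size = vocab_size + multiple - (vocab_size % multiple)
    return vocab_size

print(padded_vocab_size(32000, 64, 4))  # 32000 (already a multiple of 256)
print(padded_vocab_size(32003, 64, 4))  # 32256
```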
29 changes: 28 additions & 1 deletion colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -14,6 +14,12 @@

from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import get_current_device

from .general_checkpoint_io import GeneralCheckpointIO
@@ -32,6 +38,7 @@
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_padding_dim,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
@@ -89,6 +96,8 @@ def _model_sharder(
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
@@ -231,7 +240,6 @@ def save_sharded_model(
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.

final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
@@ -251,6 +259,7 @@
use_safetensors=use_safetensors,
use_pp_format=True,
)

if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
@@ -867,6 +876,11 @@ def gather_from_sharded_optimizer_state(
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)

padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)

state_[k] = v.detach().clone().to(device)

return state_
@@ -899,6 +913,19 @@ def shard_from_complete_optimizer_state(
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
global_shape = current_shape
if partition_dim is not None:
# pad embedding params
global_shape = (
*current_shape[:partition_dim],
current_shape[partition_dim] * self.tp_size,
*current_shape[partition_dim + 1 :],
)

padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)

if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
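To make the checkpoint-loading change concrete, here is a plain-torch sketch of the two steps `shard_from_complete_optimizer_state` now performs for a padded parameter: pad the full state tensor up to the global (padded-vocab) shape, then slice out the current tensor-parallel rank's shard. `pad_dim` is a stand-in for `to_padded_tensor`, and the shapes are assumed for illustration.

```python
import torch

def pad_dim(v: torch.Tensor, new_length: int, dim: int) -> torch.Tensor:
    # Zero-pad v along `dim` up to new_length entries (stand-in for to_padded_tensor).
    pad_rows = new_length - v.size(dim)
    if pad_rows <= 0:
        return v
    pad_shape = list(v.shape)
    pad_shape[dim] = pad_rows
    return torch.cat([v, v.new_zeros(pad_shape)], dim=dim)

tp_size, tp_rank = 4, 1
full_state = torch.randn(32003, 128)          # unpadded exp_avg loaded from the checkpoint
global_shape = (32256, 128)                   # vocab padded to a multiple of 64 * tp_size

padded = pad_dim(full_state, global_shape[0], dim=0)               # step 1: pad the vocab dim
shard = padded.split(global_shape[0] // tp_size, dim=0)[tp_rank]   # step 2: keep the local slice
print(shard.shape)                            # torch.Size([8064, 128])
```

The save path in `gather_from_sharded_optimizer_state` is the mirror image: the shards are gathered along the partition dim, and the padding rows are then stripped with `to_unpadded_tensor`.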
9 changes: 9 additions & 0 deletions colossalai/checkpoint_io/utils.py
@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
return partition_dim


def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
padding_dim = None
for dim, length in enumerate(global_shape):
if length > original_shape[dim]:
padding_dim = dim
break
return padding_dim


# ======================================
# Helper classes and functions for saving shard file
# ======================================
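A quick usage sketch of the new helper, assuming it is imported from the module shown above: given the padded ("global") shape and the original shape recorded in the checkpoint, `search_padding_dim` returns the dimension that was padded, or `None` if nothing was padded.

```python
import torch
from colossalai.checkpoint_io.utils import search_padding_dim  # module path from the diff above

print(search_padding_dim(torch.Size([32256, 4096]), torch.Size([32000, 4096])))  # 0 (vocab dim)
print(search_padding_dim(torch.Size([4096, 4096]), torch.Size([4096, 4096])))    # None
```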
7 changes: 5 additions & 2 deletions colossalai/shardformer/layer/__init__.py
@@ -1,8 +1,8 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@@ -25,6 +25,9 @@
"FusedRMSNorm",
"FusedLinear1D_Col",
"ParallelModule",
"PaddingEmbedding",
"PaddingLMHead",
"VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
"all_to_all_comm",
111 changes: 95 additions & 16 deletions colossalai/shardformer/layer/embedding.py
@@ -21,10 +21,10 @@
)

from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]


class Embedding1D(ParallelModule):
@@ -161,7 +161,80 @@ def forward(self, input_: Tensor) -> Tensor:
return output_parallel


class VocabParallelEmbedding1D(ParallelModule):
class PaddingEmbedding(PaddingParallelModule):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
weight: Optional[nn.Parameter] = None,
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
self.embed_kwargs = kwargs
self.padding_idx = padding_idx
if num_embeddings % make_vocab_size_divisible_by != 0:
self.num_embeddings = (
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
)
# create weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)

super().__init__(self.num_embeddings, num_embeddings, weight)

if weight is None:
self.reset_parameters()

def reset_parameters(self) -> None:
init.normal_(self.weight)
self._fill_padding_idx_with_zero()

def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)

def forward(self, input: Tensor) -> Tensor:
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
LazyInitContext.materialize(module)
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
device = module.weight.device
# create the parallel module
padding_embedding = PaddingEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
device=device,
weight=module.weight,
*args,
**kwargs,
)

return padding_embedding


class VocabParallelEmbedding1D(PaddingParallelModule):
r"""Embedding parallelized in the vocabulary dimension.

Args:
@@ -201,10 +274,10 @@ def __init__(
process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
@@ -214,8 +287,23 @@
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)

self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings = self.num_embeddings_per_partition
# generate weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)

# calculate new padding size
multiple = make_vocab_size_divisible_by * tensor_parallel_size
if num_embeddings % multiple != 0:
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)

# resize vocabulary size
super().__init__(self.num_embeddings, num_embeddings, weight)

# deal with tensor parallelism
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

@@ -226,13 +314,6 @@ def __init__(
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

# parameter
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@@ -243,7 +324,7 @@
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
@@ -303,11 +384,9 @@ def forward(self, input_: Tensor) -> Tensor:
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0

output_parallel = F.embedding(
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
)

# Mask the output embedding.
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
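Putting the embedding changes together, here is a minimal sketch (plain torch, assumed shapes and token ids) of what the padded, row-sharded `VocabParallelEmbedding1D` forward does on one rank: token ids outside the rank's `[vocab_start_index, vocab_end_index)` slice are remapped to row 0, looked up locally, and their output rows zeroed, so that the subsequent reduction across the tensor-parallel group produces the correct embeddings.

```python
import torch
import torch.nn.functional as F

vocab_size, hidden, tp_size, tp_rank = 32000, 8, 4, 1
multiple = 64 * tp_size                               # make_vocab_size_divisible_by * tp_size
padded_vocab = vocab_size + (-vocab_size) % multiple  # 32000 (already a multiple of 256)
per_partition = padded_vocab // tp_size               # rows owned by each rank
vocab_start = tp_rank * per_partition
vocab_end = vocab_start + per_partition

weight_shard = torch.randn(per_partition, hidden)     # this rank's slice of the padded table
input_ids = torch.tensor([[5, 9000, 17000, 31999]])

input_mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
masked_input = (input_ids - vocab_start).masked_fill(input_mask, 0)
local_out = F.embedding(masked_input, weight_shard)
local_out = local_out.masked_fill(input_mask.unsqueeze(-1), 0.0)
# In the real layer, reduce_forward (an all-reduce over the TP group) follows here.
print(local_out.shape)                                # torch.Size([1, 4, 8])
```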