
Commit 8e412a5

KKZ20 and linsj20 authored
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer

---------
Co-authored-by: linsj20 <[email protected]>

* fix bugs when useing sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megtron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several model
* remove useless code

---------
Co-authored-by: linsj20 <[email protected]>
1 parent 7e0ec5a · commit 8e412a5

33 files changed (+1627, -253 lines)

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 75 additions & 14 deletions
@@ -34,7 +34,8 @@
 
 from .pp_plugin_base import PipelinePluginBase
 
-DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
 
 PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

@@ -53,6 +54,7 @@ def __init__(
         shard_config: ShardConfig,
         dp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
@@ -61,6 +63,7 @@ def __init__(
         self.shard_config = shard_config
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.sp_group = sp_group
         self.use_dpp = use_ddp
         self.require_grad_sync = True
 
@@ -168,13 +171,24 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
         Returns:
             None
         """
-        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+
+        if self.shard_config.enable_sequence_parallelism:
+            if self.shard_config.sequence_parallelism_mode == "all_to_all":
+                return
+
+            if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized
+                # across the tensor parallelism group.
+                group = self.tp_group
+            else:
+                raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")
+
         if grads is not None:
             # Synchronize provided gradient tensors across the tensor parallelism group.
-            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+            SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads)
         else:
             # Synchronize gradients from the model across the tensor parallelism group.
-            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+            SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module)
 
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
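For the split_gather and ring branches, here is a rough sketch of what the all-reduce above amounts to. It is not the SeqParallelUtils implementation; `is_seq_partial` is a hypothetical predicate standing in for however the library marks parameters whose gradients were computed from only a local sequence shard (e.g. LayerNorm weights).

```python
import torch.distributed as dist
from torch import nn

def sync_partial_grads(model: nn.Module, group: dist.ProcessGroup, is_seq_partial) -> None:
    # Sum the partial gradients of sequence-parallel-derived parameters
    # across the chosen process group (the TP group in these two modes).
    for p in model.parameters():
        if p.grad is not None and is_seq_partial(p):
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
```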
@@ -727,10 +741,9 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
-
         if self.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
-            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+            SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
             return

@@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
     Args:
         tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
         pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        sp_size (int): The size of sequence parallelism.
         precision (str, optional): Specifies the precision of parameters during training.
             Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
             Defaults to 'fp16'.
@@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
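A minimal usage sketch based only on the arguments documented above; the distributed launch (torchrun or colossalai.launch) and the model/optimizer/dataloader construction are omitted, and defaults may differ in your version.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Megatron-style sequence parallelism: the sequence-parallel group follows the
# tensor-parallel group, so tp_size must be > 1 and sp_size is left unset.
sp_tp_plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",
)

# Ulysses/all-to-all sequence parallelism: tp_size and pp_size must both be 1,
# and sp_size chooses how many ranks the sequence is sharded across.
all_to_all_plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)

booster = Booster(plugin=all_to_all_plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion=criterion, dataloader=dataloader, lr_scheduler=lr_scheduler
# )
```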
@@ -938,13 +953,15 @@ def __init__(
         self,
         tp_size: int,
         pp_size: int,
+        sp_size: int = None,
         precision: str = "fp16",
         zero_stage: int = 0,
         enable_all_optimization: bool = False,
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        sequence_parallelism_mode: str = None,
         enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
@@ -976,14 +993,41 @@ def __init__(
         super().__init__()
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
 
         if enable_sequence_parallelism:
-            assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
+            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
+            assert (
+                self.sequence_parallelism_mode in SUPPORT_SP_MODE
+            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+            if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+                assert (
+                    tp_size > 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
+                if sp_size != 1:
+                    warnings.warn(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    )
+                self.sp_size = 1
+                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            elif self.sequence_parallelism_mode in ["all_to_all"]:
+                assert (
+                    tp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
+                assert (
+                    pp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
+                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
+                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+        else:
+            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            assert (
+                sp_size == 1 or sp_size is None
+            ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
+            self.sp_size = 1
 
         self.tp_size = tp_size
         self.pp_size = pp_size
-        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
self.cpu_offload = cpu_offload
@@ -992,7 +1036,7 @@ def __init__(
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
@@ -1033,16 +1077,22 @@ def __init__(
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
+            self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        else:
+            self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
 
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
+            sequence_parallel_process_group=self.sp_group,
             pipeline_stage_manager=self.stage_manager,
             enable_tensor_parallelism=self.tp_size > 1,
             enable_all_optimization=self.enable_all_optimization,
             enable_fused_normalization=self.enable_fused_normalization,
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
+            sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             gradient_checkpoint_config=gradient_checkpoint_config,
@@ -1113,13 +1163,23 @@ def configure(
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
-            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+                self.dp_size == 1
+                and self.pp_size == 1
+                and self.enable_sequence_parallelism
+                and self.sequence_parallelism_mode == "all_to_all"
+            )
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
+            else:
+                dp_group = self.dp_group
             model = HybridParallelModule(
                 model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=self.dp_group,
+                dp_group=dp_group,
                 tp_group=self.tp_group,
+                sp_group=self.sp_group,
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
@@ -1149,7 +1209,8 @@ def configure(
                         tp_process_group=self.tp_group,
                     )
             else:
-                if self.dp_size == 1:
+                zero_dp_size = dist.get_world_size(dp_group)
+                if zero_dp_size == 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you are not intended to use cpu_offload, please consider set zero_stage=0."
@@ -1161,7 +1222,7 @@ def configure(
                     model,
                     use_pipeline=self.enable_pipeline_parallelism,
                     param_info=param_info,
-                    dp_process_group=self.dp_group,
+                    dp_process_group=dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     verbose=True,
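For the all_to_all branch, data-parallel gradient synchronization (DDP or ZeRO) therefore runs over the combined DP x SP plane rather than the DP axis alone. A hedged sketch of how that merged group can be built directly with ProcessGroupMesh, meant to be launched with torchrun on 8 processes; the axis constants mirror the ones defined at the top of this file, and the mesh sizes are illustrative:

```python
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3

dist.init_process_group(backend="gloo")  # "nccl" for GPU jobs

# world size 8 arranged as (dp=2, pp=1, tp=1, sp=4)
mesh = ProcessGroupMesh(2, 1, 1, 4)
sp_group = mesh.get_group_along_axis(SP_AXIS)                   # 4 ranks per group
dp_sp_group = mesh.create_group_along_axis([DP_AXIS, SP_AXIS])  # all 8 ranks
print(dist.get_world_size(sp_group), dist.get_world_size(dp_sp_group))
```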

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

Lines changed: 4 additions & 0 deletions
@@ -254,6 +254,9 @@ def __init__(
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        # TODO: Currently moe only support partially sequence parallel
+        self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             pipeline_stage_manager=self.stage_manager,
@@ -365,6 +368,7 @@ def configure(
                 shard_config=self.shard_config,
                 dp_group=self.dp_group,
                 tp_group=self.tp_group,
+                sp_group=self.sp_group,
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,

colossalai/cluster/process_group_mesh.py

Lines changed: 29 additions & 8 deletions
@@ -161,7 +161,7 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
 
     @staticmethod
     def get_coords_along_axis(
-        base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
+        base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
     ) -> List[Tuple[int, ...]]:
         """Get coordinates along the given axis.
 
@@ -173,13 +173,28 @@ def get_coords_along_axis(
         Returns:
             List[Tuple[int, ...]]: Coordinates along the axis.
         """
-        coords_in_group = []
-        for idx in indices_at_axis:
-            coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
+        if isinstance(axis, int):
+            axis = [axis,]
+            assert isinstance(indices_at_axis[0], int)
+            indices_at_axis = [indices_at_axis,]
+
+        def add_index(base_coord, axis, indices_at_axis):
+            coords_in_group = []
+            for idx in indices_at_axis:
+                coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
+            return coords_in_group
+
+        coords_in_group = [base_coord]
+        for ax, indices_at_ax in zip(axis, indices_at_axis):
+            new_coords_in_group = []
+            for coords in coords_in_group:
+                new_coords_in_group += add_index(coords, ax, indices_at_ax)
+            coords_in_group = new_coords_in_group
+
         return coords_in_group
 
     def create_group_along_axis(
-        self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+        self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None
     ) -> ProcessGroup:
         """Create all process groups along the given axis, and return the one which the current process belongs to.
 
@@ -191,10 +206,17 @@ def create_group_along_axis(
         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
         """
-        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+        if isinstance(axis, int):
+            axis = [axis,]
+            if indices_at_axis is not None:
+                assert isinstance(indices_at_axis[0], int)
+                indices_at_axis = [indices_at_axis,]
+
+        indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
         reduced_shape = list(self._shape)
         # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
-        reduced_shape[axis] = 1
+        for ax in axis:
+            reduced_shape[ax] = 1
         target_group = None
         # use Cartesian product to generate all combinations of coordinates
         for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
@@ -225,4 +247,3 @@ def get_group_along_axis(
             # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
             return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
         return self._ranks_to_group[ranks_in_group]
-
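Since get_coords_along_axis is a static method that only manipulates coordinate tuples, the new multi-axis behaviour can be exercised without any distributed setup. A small sketch on a hypothetical (dp=2, pp=1, tp=1, sp=2) mesh, sweeping the DP and SP axes together:

```python
from colossalai.cluster import ProcessGroupMesh

# Coordinates are (dp, pp, tp, sp); sweeping axes 0 (DP) and 3 (SP) from the
# base coordinate enumerates the whole DP x SP plane at pp=0, tp=0.
coords = ProcessGroupMesh.get_coords_along_axis(
    base_coord=(0, 0, 0, 0),
    axis=[0, 3],
    indices_at_axis=[[0, 1], [0, 1]],
)
print(coords)
# -> [(0, 0, 0, 0), (0, 0, 0, 1), (1, 0, 0, 0), (1, 0, 0, 1)]
```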

colossalai/shardformer/layer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 from .attn import AttnMaskType, ColoAttention
+from ._operation import all_to_all_comm
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
@@ -26,4 +27,5 @@
     "ParallelModule",
     "AttnMaskType",
     "ColoAttention",
+    "all_to_all_comm",
 ]
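The newly exported all_to_all_comm lives in colossalai/shardformer/layer/_operation.py, so its exact signature is not visible in this diff. As a rough, hedged illustration of the idea behind the all-to-all (Ulysses-style) sequence parallelism it supports, re-sharding activations from sequence-split to head-split before attention and back again afterwards, something along these lines; the function name and tensor layout here are assumptions, not the library's implementation:

```python
import torch
import torch.distributed as dist

def seq_to_head_all_to_all(x: torch.Tensor, sp_group: dist.ProcessGroup) -> torch.Tensor:
    """x: [batch, seq_len // sp, num_heads, head_dim], sharded along the sequence.
    Returns [batch, seq_len, num_heads // sp, head_dim], sharded along the heads."""
    sp = dist.get_world_size(sp_group)
    b, s_local, h, d = x.shape
    assert h % sp == 0, "num_heads must be divisible by the sequence-parallel size"
    # One head-chunk per peer rank, stacked on dim 0 as all_to_all_single expects.
    x = x.reshape(b, s_local, sp, h // sp, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=sp_group)
    # Each received chunk carries a different sequence shard of our head slice;
    # stitch the sequence back together in rank order.
    return out.permute(1, 0, 2, 3, 4).reshape(b, sp * s_local, h // sp, d)
```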
