* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LLaMA model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <[email protected]>
* fix bugs when using sp and flash attention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megatron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatibility
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loosen tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several models
* remove useless code
---------
Co-authored-by: linsj20 <[email protected]>
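Several commits above ("support gpt model all-to-all sp", "support flash attention for ulysses-style sp", "using all_to_all_single when batch size is 1") concern the all-to-all ("Ulysses-style") mode, in which each rank trades its local sequence slice (with all attention heads) for the full sequence (with a local slice of heads) before computing attention. The snippet below is a minimal sketch of that redistribution with `torch.distributed.all_to_all_single`, under assumed tensor shapes and an already-initialized process group; it is illustrative only and not the code added in this PR.

```python
# Rough sketch (not the PR's code) of the all-to-all redistribution used by
# Ulysses-style sequence parallelism: trade a local sequence slice (all heads)
# for the full sequence (a local slice of heads) before attention.
import torch
import torch.distributed as dist


def seq_shard_to_head_shard(x: torch.Tensor, group=None) -> torch.Tensor:
    """[bsz, seq_len // P, num_heads, head_dim] -> [bsz, seq_len, num_heads // P, head_dim]."""
    world_size = dist.get_world_size(group)
    bsz, sub_seq, num_heads, head_dim = x.shape
    assert num_heads % world_size == 0, "num_heads must be divisible by the sp group size"

    # Put the P chunks to exchange on dim 0: [P, num_heads // P, bsz, sub_seq, head_dim]
    send = (
        x.reshape(bsz, sub_seq, world_size, num_heads // world_size, head_dim)
        .permute(2, 3, 0, 1, 4)
        .contiguous()
    )
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=group)

    # recv[r] holds rank r's sequence chunk for our local head group; stitch the
    # chunks back into the full sequence, in rank order.
    return (
        recv.permute(2, 0, 3, 1, 4)  # [bsz, P, sub_seq, num_heads // P, head_dim]
        .reshape(bsz, world_size * sub_seq, num_heads // world_size, head_dim)
    )
```

In practice the exchange is wrapped in an autograd function so the backward pass performs the inverse all-to-all; the commits above ("update split and gather auto grad func", "using all_to_all_single when batch size is 1") point at the corresponding parts of the PR.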
@@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
     Args:
         tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
         pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        sp_size (int): The size of sequence parallelism.
         precision (str, optional): Specifies the precision of parameters during training.
             Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
             Defaults to 'fp16'.
@@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
@@ -938,13 +953,15 @@ def __init__(
         self,
         tp_size: int,
         pp_size: int,
+        sp_size: int = None,
         precision: str = "fp16",
         zero_stage: int = 0,
         enable_all_optimization: bool = False,
         enable_fused_normalization: bool = False,
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        sequence_parallelism_mode: str = None,
         enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
@@ -976,14 +993,41 @@ def __init__(
         super().__init__()
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

         if enable_sequence_parallelism:
-            assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
+            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
+            if sp_size != 1:
+                warnings.warn(
+                    f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                )
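For context, here is a minimal sketch of how the new arguments introduced in this diff might be passed to the plugin. Only the `HybridParallelPlugin` argument names come from the diff; the import paths, launcher call, and `Booster` wiring are assumptions based on typical ColossalAI usage and may differ across versions.

```python
# Hypothetical usage sketch: only the HybridParallelPlugin arguments are taken
# from this diff; the imports, launch call, and Booster wiring are assumed
# ColossalAI boilerplate.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})  # assumed launcher; run under torchrun

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,  # honored by "all_to_all"; for "split_gather"/"ring" it follows tp_size (see warning above)
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",  # or "split_gather" / "ring" (require tp_size > 1)
    enable_flash_attention=True,
    precision="bf16",
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```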