
Commit dba3e04

littskKKZ20 authored and committed
[hotfix] Add layer norm gradients all-reduce for sequence parallel (hpcaitech#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (hpcaitech#4915)
* Add layer norm gradients all-reduce for sequence parallel.
* skip pipeline inference test
* [hotfix] fixing polices of sequence parallel (hpcaitech#4922)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy

---------
Co-authored-by: littsk <[email protected]>

* Hotfix/add grad all reduce for sequence parallel (hpcaitech#4927)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
* fix bug using wrong variables

---------
Co-authored-by: littsk <[email protected]>

* fix policy initialization
* fix bloom and chatglm policices
* polish code of handling layernorm
* fix moe module
* polish code of class initializing

---------
Co-authored-by: Zhongkai Zhao <[email protected]>
1 parent 8c58648 commit dba3e04

File tree: 30 files changed, +1112 −547 lines


colossalai/booster/plugin/hybrid_parallel_plugin.py (341 additions, 31 deletions)

Large diff not rendered by default.
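Because this file's diff is not rendered, the following is only a hedged illustration of the mechanism it wires up, based on the SeqParallelUtils helper added in colossalai/shardformer/layer/utils.py below. SPGradSyncWrapper and sync_sp_grads are hypothetical names for this sketch, not the plugin's actual classes or methods.

# Illustrative only: how a model wrapper could flush partially derived layer-norm
# gradients after backward, using the helper added in this commit.
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.shardformer.layer.utils import SeqParallelUtils


class SPGradSyncWrapper(nn.Module):
    def __init__(self, module: nn.Module, tp_group: ProcessGroup):
        super().__init__()
        self.module = module
        self.tp_group = tp_group

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def sync_sp_grads(self):
        # All-reduce gradients of parameters marked as partially derived
        # (layer-norm weights/biases under sequence parallelism) across the TP group.
        SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)

The intended call order is loss.backward(), then the gradient sync, then optimizer.step(), so the optimizer consumes the fully reduced layer-norm gradients.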

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (8 additions, 1 deletion)

@@ -338,7 +338,14 @@ def configure(
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(
-                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+                module=model,
+                precision=self.precision,
+                shard_config=self.shard_config,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_ddp=use_ddp,
+                ddp_config=self.ddp_config,
+                custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
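For orientation, a hedged sketch of how user code reaches these code paths, assuming the enable_sequence_parallelism flag on HybridParallelPlugin and the launch_from_torch(config={}) entry point used in this release series; argument names may differ across versions, and the script must be started with a distributed launcher such as colossalai run or torchrun.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumed entry point for this release series.
colossalai.launch_from_torch(config={})

# tp_size > 1 with sequence parallelism enabled is the configuration whose
# layer-norm gradients this commit begins all-reducing over the TP group.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, enable_sequence_parallelism=True)
booster = Booster(plugin=plugin)

# `model` and `optimizer` are assumed to be a supported transformer
# (e.g. a Hugging Face GPT-2 or LLaMA model) and its optimizer, defined elsewhere:
# model, optimizer, *_ = booster.boost(model, optimizer)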

colossalai/inference/tensor_parallel/engine.py (2 additions, 2 deletions)

@@ -218,10 +218,10 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
         ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
         model_name = model.__class__.__name__
         assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
-
+
         model = model.model if self.shard_config.inference_gptq else model
+        policy = get_autopolicy(model, shard_config=self.shard_config)

-        policy = get_autopolicy(model, inference_only=True)
         self.model, _ = shardformer.optimize(model, policy)

         if self.shard_config.inference_gptq:

colossalai/shardformer/README.md (18 additions, 0 deletions)

@@ -235,6 +235,14 @@ class SubModuleReplacementDescription:


 class Policy(ABC):
+    r"""
+    The base class for all the policies. For each different model, it should have a different policy class,
+    like BertPolicy for Bert Model or OPTPolicy for OPT model.
+
+    Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
+    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
+    If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
+    """

     def __init__(self)
         self.model = None
@@ -245,6 +253,16 @@ class Policy(ABC):
         """
         self.model = model

+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be perform
+        """
+        self.shard_config = shard_config
+
+        self.config_sanity_check()
+
     @abstractmethod
     def preprocess(self) -> nn.Module:
         """

colossalai/shardformer/layer/__init__.py (4 additions, 1 deletion)

@@ -2,7 +2,7 @@
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
 from .loss import cross_entropy_1d
-from .normalization import FusedLayerNorm, FusedRMSNorm
+from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

@@ -16,6 +16,9 @@
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
+    "BaseLayerNorm",
+    "LayerNorm",
+    "RMSNorm",
     "FusedLayerNorm",
     "FusedRMSNorm",
     "FusedLinear1D_Col",

colossalai/shardformer/layer/normalization.py (143 additions, 8 deletions)

@@ -1,11 +1,15 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import warnings
+from abc import ABC, abstractmethod
+
 import torch.nn as nn
 from colossalai.lazy import LazyInitContext
 from ._operation import hook_paramter_in_backward

-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]

 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
@@ -77,21 +81,128 @@ def forward(self, input):
         return output


-class FusedLayerNorm:
+class BaseLayerNorm(ABC):
+    @abstractmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+        """
+        Convert a native PyTorch layer normalization module to a specific layer normalization module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native PyTorch layer normalization module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The specific layer normalization module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of the supported layer normalization type.
+        """
+
+
+class RMSNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "FusedLayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        """
+        Convert a native RMSNorm module to colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The RMSNorm module.
+        """
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+
+        return module
+
+
+class LayerNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "LayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native pytorch layer norm module to colossalai layer norm module,
+        and optionally marking parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The LayerNorm module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
+        """
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+
+        return module
+
+
+class FusedLayerNorm(BaseLayerNorm):
     r"""
     This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """

     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
         )

     @staticmethod
-    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module
+        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
+        and optionally marking parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """

         LazyInitContext.materialize(module)
@@ -120,21 +231,39 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         layernorm.weight = module.weight
         layernorm.bias = module.bias

+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
         return layernorm


-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
         )

     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native RMSNorm module module to FusedRMSNorm module provided by apex,
+        and optionally marking parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: FusedRMSNorm module.
+        """
         LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm
         if module.__class__.__name__ == "LlamaRMSNorm":
@@ -151,4 +280,10 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:

         rmsnorm.weight = module.weight

+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
         return rmsnorm
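To make the new wrappers concrete, a small usage sketch grounded in the code above; it only needs colossalai importable, and no distributed setup is required just to tag parameters.

import torch.nn as nn

from colossalai.shardformer.layer import LayerNorm
from colossalai.shardformer.layer.utils import SeqParallelUtils

norm = nn.LayerNorm(1024)

# from_native_module returns the same nn.LayerNorm instance; with
# sp_partial_derived=True its weight and bias are tagged so that their
# gradients are all-reduced across the tensor-parallel group later.
wrapped = LayerNorm.from_native_module(norm, sp_partial_derived=True)

assert wrapped is norm
assert SeqParallelUtils.is_sp_partial_derived_param(wrapped.weight)
assert SeqParallelUtils.is_sp_partial_derived_param(wrapped.bias)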

colossalai/shardformer/layer/utils.py (75 additions, 1 deletion)

@@ -1,8 +1,82 @@
 from contextlib import contextmanager
+from typing import List

 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup, get_world_size
+
+
+class SeqParallelUtils:
+    @staticmethod
+    def marked_as_sp_partial_derived_param(param):
+        """
+        Mark a parameter as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to mark as partially derived.
+        """
+        setattr(param, "partial_derived", True)
+
+    @staticmethod
+    def is_sp_partial_derived_param(param):
+        """
+        Check if a parameter is marked as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to check.
+
+        Returns:
+            bool: True if the parameter is marked as partially derived, False otherwise.
+        """
+        return getattr(param, "partial_derived", False)
+
+    @staticmethod
+    def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+        """
+        Allreduce partial derived gradients across the specified process group.
+
+        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
+
+        Args:
+            tp_group (ProcessGroup): The process group for gradient synchronization.
+            model (nn.Module): The model from which gradients will be synchronized.
+            grads (List[torch.Tensor]): The list of gradients to be synchronized.
+
+        Raises:
+            AssertionError: If both `model` and `grads` are provided or neither is provided.
+        """
+        # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
+        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
+
+        # Get the size of the process group, which determines whether synchronization is needed.
+        tp_size = get_world_size(tp_group) if tp_group is not None else 1
+
+        if tp_size == 1:
+            # If the process group size is 1, no synchronization is required.
+            return
+
+        if model is not None:
+            # If `model` is provided, extract partial derived gradients from the model's parameters.
+            grads = []
+            for p in model.parameters():
+                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
+                    grads.append(p.grad.data)
+
+            # Flatten and reduce the gradients using the specified process group.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+
+            # Unflatten the synchronized gradients and update the model's gradients.
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+        else:
+            # If `grads` are provided explicitly, synchronize those gradients directly.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)


 class Randomizer:
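A short usage sketch of the new helper, assuming the caller already holds the tensor-parallel process group and runs under a distributed launcher such as torchrun; the function names here are illustrative wrappers, only the SeqParallelUtils call is from the diff.

import torch
import torch.distributed as dist

from colossalai.shardformer.layer.utils import SeqParallelUtils


def sync_layernorm_grads(model: torch.nn.Module, tp_group: dist.ProcessGroup):
    # Variant 1: let the helper scan model.parameters() for parameters that were
    # marked via marked_as_sp_partial_derived_param and all-reduce their gradients.
    SeqParallelUtils.allreduce_partial_data_grad(tp_group=tp_group, model=model)


def sync_explicit_grads(grads, tp_group: dist.ProcessGroup):
    # Variant 2: pass an explicit list of gradient tensors instead of a model.
    SeqParallelUtils.allreduce_partial_data_grad(tp_group=tp_group, grads=grads)

Either variant is meant to run after loss.backward() and before optimizer.step().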
