
Commit 3978b37

[hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
* Add layer norm gradients all-reduce for sequence parallel.
* Modify docs and polish code
* Polish code
* skip pipeline inference test
1 parent 7768afb commit 3978b37
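The idea behind the fix, sketched in a hedged, illustrative form (the function and variable names below are placeholders, not the plugin's actual code): with sequence parallelism, each rank in the tensor-parallel group sees only a slice of the sequence, so the layer norm weight and bias gradients it computes are partial and must be summed across the group before the optimizer step.

# Illustrative sketch only; assumes torch.distributed is initialized and
# `tp_group` is the tensor-parallel process group used for sequence parallelism.
import torch.distributed as dist
import torch.nn as nn

def allreduce_layernorm_grads(norm: nn.LayerNorm, tp_group: dist.ProcessGroup) -> None:
    # Each rank computed these gradients from its own sequence shard, so sum them.
    for p in (norm.weight, norm.bias):
        if p is not None and p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=tp_group)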

File tree

28 files changed: +1255 -567 lines changed


colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 341 additions & 31 deletions
Large diffs are not rendered by default.

colossalai/shardformer/README.md

Lines changed: 18 additions & 5 deletions
@@ -235,15 +235,28 @@ class SubModuleReplacementDescription:
 
 
 class Policy(ABC):
+    r"""
+    The base class for all the policies. For each different model, it should have a different policy class,
+    like BertPolicy for Bert Model or OPTPolicy for OPT model.
 
-    def __init__(self)
-        self.model = None
+    Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
+    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
+    If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
+    """
 
-    def set_model(self, model: nn.Module) -> None:
+    def __init__(self, model: Optional[Module] = None, shard_config: Optional[ShardConfig] = None) -> None:
         """
-        Set model as an attribute of the Policy object so that we can access the model's attributes.
+        Initialize a Policy object.
+
+        This method sets the model and shard configuration for the policy and performs a configuration sanity check.
+
+        Args:
+            model (Optional[Module]): The model to be used with this policy.
+            shard_config (Optional[ShardConfig]): The sharding configuration for the policy.
         """
-        self.model = model
+        self.model: Optional[Module] = model
+        self.shard_config: Optional[ShardConfig] = shard_config
+        self.config_sanity_check()
 
     @abstractmethod
     def preprocess(self) -> nn.Module:
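To illustrate the constructor change above, here is a hedged sketch of a minimal custom policy. Only `preprocess` appears in this excerpt; the other overridden hooks (`config_sanity_check`, `module_policy`, `postprocess`) are assumed from the rest of the codebase and may differ in signature.

# Hedged sketch of a custom policy under the new __init__(model, shard_config) contract.
import torch.nn as nn
from colossalai.shardformer.policies.base_policy import Policy

class MyModelPolicy(Policy):
    def config_sanity_check(self):
        # Called by __init__; validate self.shard_config against the model here.
        pass

    def preprocess(self) -> nn.Module:
        # e.g. pad the vocabulary so it divides the tensor-parallel world size.
        return self.model

    def module_policy(self):
        # Return submodule replacement descriptions; an empty dict shards nothing.
        return {}

    def postprocess(self):
        return self.model

Such a policy would then be constructed as MyModelPolicy(model, shard_config), mirroring the `return policy(model, shard_config)` call in auto_policy.py further down.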

colossalai/shardformer/layer/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -2,7 +2,7 @@
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
 from .loss import cross_entropy_1d
-from .normalization import FusedLayerNorm, FusedRMSNorm
+from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 
@@ -16,6 +16,9 @@
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
+    "BaseLayerNorm",
+    "LayerNorm",
+    "RMSNorm",
     "FusedLayerNorm",
     "FusedRMSNorm",
     "FusedLinear1D_Col",

colossalai/shardformer/layer/normalization.py

Lines changed: 146 additions & 8 deletions
@@ -1,11 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+from abc import ABC, abstractmethod
 
 import torch.nn as nn
 
 from colossalai.lazy import LazyInitContext
 
-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
@@ -35,23 +38,133 @@
 ]
 
 
-class FusedLayerNorm:
+class BaseLayerNorm(ABC):
+    @abstractmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+        """
+        Convert a native PyTorch layer normalization module to a specific layer normalization module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native PyTorch layer normalization module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The specific layer normalization module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of the supported layer normalization type.
+        """
+
+
+class RMSNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "FusedLayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        """
+        Convert a native RMSNorm module to colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The RMSNorm module.
+        """
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+
+        return module
+
+
+class LayerNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "LayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native pytorch layer norm module to colossalai layer norm module,
+        and optionally marking parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The LayerNorm module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
+        """
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+
+        return module
+
+
+class FusedLayerNorm(BaseLayerNorm):
     r"""
     This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module
+        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
+        and optionally marking parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
         # check if apex is installed
+
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
         try:
             pass
         except ImportError:
@@ -85,22 +198,41 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
 
         layernorm.weight = module.weight
         layernorm.bias = module.bias
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
         return layernorm
 
 
-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native RMSNorm module module to FusedRMSNorm module provided by apex,
+        and optionally marking parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: FusedRMSNorm module.
+        """
         try:
             from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
         except ImportError:
@@ -124,4 +256,10 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
 
         rmsnorm.weight = module.weight
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
         return rmsnorm
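A hedged usage sketch for the wrappers added above, using the pure-PyTorch LayerNorm path (the FusedLayerNorm/FusedRMSNorm paths additionally require apex to be installed):

import torch.nn as nn
from colossalai.shardformer.layer import LayerNorm

norm = nn.LayerNorm(1024)
# from_native_module returns the (materialized) module; with sp_partial_derived=True,
# weight and bias are tagged so their gradients can later be all-reduced for sequence parallelism.
norm = LayerNorm.from_native_module(norm, sp_partial_derived=True)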

colossalai/shardformer/layer/utils.py

Lines changed: 75 additions & 1 deletion
@@ -1,8 +1,82 @@
 from contextlib import contextmanager
+from typing import List
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup, get_world_size
+
+
+class SeqParallelUtils:
+    @staticmethod
+    def marked_as_sp_partial_derived_param(param):
+        """
+        Mark a parameter as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to mark as partially derived.
+        """
+        setattr(param, "partial_derived", True)
+
+    @staticmethod
+    def is_sp_partial_derived_param(param):
+        """
+        Check if a parameter is marked as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to check.
+
+        Returns:
+            bool: True if the parameter is marked as partially derived, False otherwise.
+        """
+        return getattr(param, "partial_derived", False)
+
+    @staticmethod
+    def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+        """
+        Allreduce partial derived gradients across the specified process group.
+
+        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
+
+        Args:
+            tp_group (ProcessGroup): The process group for gradient synchronization.
+            model (nn.Module): The model from which gradients will be synchronized.
+            grads (List[torch.Tensor]): The list of gradients to be synchronized.
+
+        Raises:
+            AssertionError: If both `model` and `grads` are provided or neither is provided.
+        """
+        # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
+        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
+
+        # Get the size of the process group, which determines whether synchronization is needed.
+        tp_size = get_world_size(tp_group) if tp_group is not None else 1
+
+        if tp_size == 1:
+            # If the process group size is 1, no synchronization is required.
+            return
+
+        if model is not None:
+            # If `model` is provided, extract partial derived gradients from the model's parameters.
+            grads = []
+            for p in model.parameters():
+                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
+                    grads.append(p.grad.data)
+
+            # Flatten and reduce the gradients using the specified process group.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+
+            # Unflatten the synchronized gradients and update the model's gradients.
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+        else:
+            # If `grads` are provided explicitly, synchronize those gradients directly.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
 
 
 class Randomizer:
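A hedged sketch of how allreduce_partial_data_grad would typically be called in a training step; `model`, `loss`, `optimizer`, and the tensor-parallel group `tp_group` are assumed to exist and torch.distributed to be initialized:

from colossalai.shardformer.layer.utils import SeqParallelUtils

loss.backward()
# Sum the gradients of parameters previously tagged via marked_as_sp_partial_derived_param
# (e.g. layer-norm weights and biases) across the tensor-parallel group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=tp_group, model=model)
optimizer.step()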

colossalai/shardformer/policies/auto_policy.py

Lines changed: 5 additions & 4 deletions
@@ -4,6 +4,7 @@
 
 import torch.nn as nn
 
+from ..shard.shard_config import ShardConfig
 from .base_policy import Policy
 
 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -197,7 +198,7 @@ def _fullname(obj):
     return module + "." + klass.__qualname__
 
 
-def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
+def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if inference_only:
+    if ShardConfig.inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)
@@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
        )
     else:
-        policy = import_policy(policy_location, inference_only)
-        return policy()
+        policy = import_policy(policy_location, ShardConfig.inference_only)
+        return policy(model, shard_config)
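A hedged sketch of the updated call pattern (the model and config below are illustrative; it assumes the model class has a registered policy and that ShardConfig can be constructed with its defaults):

from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy

# `model` is any supported nn.Module, e.g. a Hugging Face transformer instance.
shard_config = ShardConfig()                  # illustrative default configuration
policy = get_autopolicy(model, shard_config)  # the policy is now built with (model, shard_config)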
