 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+from abc import ABC, abstractmethod
 
 import torch.nn as nn
 
 from colossalai.lazy import LazyInitContext
 
-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
...
 ]
 
 
-class FusedLayerNorm:
+class BaseLayerNorm(ABC):
+    @abstractmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+        """
+        Convert a native PyTorch layer normalization module to a specific layer normalization module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native PyTorch layer normalization module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The specific layer normalization module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of the supported layer normalization type.
+        """
+
+
+class RMSNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "RMSNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to the colossalai RMSNorm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        """
+        Convert a native RMSNorm module to the colossalai RMSNorm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The RMSNorm module.
+        """
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+
+        return module
+
+
+class LayerNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "LayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native PyTorch LayerNorm module to the colossalai LayerNorm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native PyTorch LayerNorm module to the colossalai LayerNorm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The LayerNorm module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
+        """
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+
+        return module
+
+
+class FusedLayerNorm(BaseLayerNorm):
     r"""
     This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native PyTorch LayerNorm module to the FusedLayerNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module
+        Convert a native PyTorch LayerNorm module to the FusedLayerNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
         # check if apex is installed
+
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
         try:
             pass
         except ImportError:
@@ -85,22 +198,41 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
 
         layernorm.weight = module.weight
         layernorm.bias = module.bias
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
         return layernorm
 
 
-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to the FusedRMSNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: FusedRMSNorm module.
+        """
         try:
             from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
         except ImportError:
@@ -124,4 +256,10 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
 
         rmsnorm.weight = module.weight
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
         return rmsnorm
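
For reference, a minimal usage sketch (not part of the commit) of the wrappers added above. It assumes apex is installed for the fused variant, and the import path is an assumption, since this diff shows only the module body, not where it lives in the package:

import torch.nn as nn

# Assumed import path -- the diff above does not show the module's location.
from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm

native_norm = nn.LayerNorm(1024)

# Replace the native norm with the apex fused kernel. When sequence parallelism is
# enabled, sp_partial_derived=True marks weight/bias for gradient aggregation.
fused_norm = FusedLayerNorm.from_native_module(native_norm, sp_partial_derived=True)

# The plain LayerNorm wrapper keeps the original nn.LayerNorm module and only
# performs the materialization/marking steps.
plain_norm = LayerNorm.from_native_module(nn.LayerNorm(1024), sp_partial_derived=False)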