 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import warnings
+from abc import ABC, abstractmethod
+
 import torch.nn as nn
 from colossalai.lazy import LazyInitContext
 from ._operation import hook_paramter_in_backward
 
-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
 
 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
@@ -77,21 +81,128 @@ def forward(self, input):
         return output
 
 
-class FusedLayerNorm:
+class BaseLayerNorm(ABC):
+    @abstractmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+        """
+        Convert a native PyTorch layer normalization module to a specific layer normalization module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native PyTorch layer normalization module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The specific layer normalization module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of the supported layer normalization type.
+        """
+
+
+class RMSNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around RMSNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "RMSNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to a colossalai RMSNorm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        """
+        Convert a native RMSNorm module to a colossalai RMSNorm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The RMSNorm module.
+        """
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+
+        return module
+
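+# A minimal usage sketch (illustrative; LlamaRMSNorm and its constructor arguments are
+# assumed to come from HuggingFace transformers, which is not imported by this file):
+#
+#     from transformers.models.llama.modeling_llama import LlamaRMSNorm
+#     norm = LlamaRMSNorm(hidden_size=4096, eps=1e-6)
+#     norm = RMSNorm.from_native_module(norm, sp_partial_derived=True)
+#
+# The module is returned unchanged apart from lazy materialization; with sp_partial_derived=True
+# its weight is marked via SeqParallelUtils for gradient aggregation under sequence parallelism.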
+
+class LayerNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "LayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native PyTorch LayerNorm module to a colossalai LayerNorm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native PyTorch LayerNorm module to a colossalai LayerNorm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The LayerNorm module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
+        """
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
+        LazyInitContext.materialize(module)
+
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+
+        return module
+
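+# A minimal usage sketch (illustrative): converting a plain nn.LayerNorm. The hidden size
+# used here is an arbitrary example value.
+#
+#     ln = nn.LayerNorm(1024)
+#     ln = LayerNorm.from_native_module(ln, sp_partial_derived=True)
+#
+# With sp_partial_derived=True, both ln.weight and ln.bias are marked for gradient
+# aggregation during backpropagation under sequence parallelism.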
+
+class FusedLayerNorm(BaseLayerNorm):
     r"""
     This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native PyTorch LayerNorm module to the FusedLayerNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module
+        Convert a native PyTorch LayerNorm module to the FusedLayerNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The apex fused layer norm module (FastLayerNorm or FusedLayerNorm).
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
 
         LazyInitContext.materialize(module)
@@ -120,21 +231,39 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         layernorm.weight = module.weight
         layernorm.bias = module.bias
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
         return layernorm
 
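+# A minimal usage sketch (illustrative; requires apex to be installed): converting a plain
+# nn.LayerNorm to the apex fused implementation.
+#
+#     ln = nn.LayerNorm(1024, eps=1e-5)
+#     fused_ln = FusedLayerNorm.from_native_module(ln, sp_partial_derived=False)
+#
+# The returned module is one of apex's FastLayerNorm or FusedLayerNorm and reuses the
+# original weight and bias parameters, as shown above.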
 
-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to the FusedRMSNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The FusedRMSNorm module provided by apex.
+        """
         LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm
         if module.__class__.__name__ == "LlamaRMSNorm":
@@ -151,4 +280,10 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
 
         rmsnorm.weight = module.weight
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
         return rmsnorm
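+
+# A minimal usage sketch (illustrative; requires apex, and assumes a HuggingFace-style
+# LlamaRMSNorm, which this converter handles as a special case):
+#
+#     from transformers.models.llama.modeling_llama import LlamaRMSNorm
+#     norm = LlamaRMSNorm(hidden_size=4096, eps=1e-6)
+#     fused_norm = FusedRMSNorm.from_native_module(norm, sp_partial_derived=True)
+#
+# The apex FusedRMSNorm reuses norm.weight, which is additionally marked for gradient
+# aggregation when sp_partial_derived is True.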