 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union

 import torch
@@ -33,11 +34,11 @@
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.enums import PrecisionType
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import STEP_OUTPUT

 _COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
+_COLOSSALAI_GREATER_0_1_10 = RequirementCache("colossalai>0.1.10")
 if TYPE_CHECKING and _COLOSSALAI_AVAILABLE:
     with _patch_cuda_is_available():
         from colossalai.utils.model.colo_init_context import ColoInitContext
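
# Sketch, not part of the diff above: how the new `_COLOSSALAI_GREATER_0_1_10`
# flag is intended to be used. A `RequirementCache` wraps a pip-style requirement
# string and is truthy only when the installed package satisfies it, which is what
# lets `setup_precision_plugin` below branch between the pre- and post-0.1.10
# ColossalAI APIs. The import path here is an assumption for a standalone snippet;
# the strategy module already has `RequirementCache` in scope.
from lightning_utilities.core.imports import RequirementCache

_COLOSSALAI_GREATER_0_1_10 = RequirementCache("colossalai>0.1.10")

if _COLOSSALAI_GREATER_0_1_10:
    print("colossalai > 0.1.10: wrap the model with GeminiDDP")
else:
    print("colossalai <= 0.1.10: build ChunkManager + GeminiManager + ZeroDDP")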
@@ -130,7 +131,7 @@ def __init__(
         force_outputs_fp32: bool = False,
         gpu_margin_mem_ratio: float = 0.0,
         chunk_search_range: int = 64 * 1024**2,
-        chunk_search_n_grids: int = 1024,
+        chunk_search_n_grids: int = 4096,
         min_chunk_size: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
@@ -146,7 +147,7 @@ def __init__(
         precision_plugin: Optional[ColossalAIPrecisionPlugin] = None,
     ) -> None:
         if not _COLOSSALAI_AVAILABLE:
-            raise MisconfigurationException(
+            raise ModuleNotFoundError(
                 "To use the `ColossalAIStrategy`, please install `colossalai` first. "
                 "Download `colossalai` by consulting `https://colossalai.org/download`."
             )
@@ -237,7 +238,8 @@ def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any)
                 if getattr(module, "_colossalai_module", False) is True:
                     return
                 super()._post_init_method(module, *args, **kwargs)
-                module._colossalai_module = True  # type: ignore[assignment]
+                for sub_module in module.modules():
+                    sub_module._colossalai_module = True  # type: ignore[assignment]

         return ModelShardedContext()

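
# Sketch, not part of the diff: why the change above loops over `module.modules()`.
# `Module.modules()` yields the module itself and every nested submodule, so the
# `_colossalai_module` flag now marks the whole subtree rather than only the
# top-level module. `net` below is an illustrative stand-in.
import torch

net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())
for sub_module in net.modules():
    sub_module._colossalai_module = True

assert all(getattr(m, "_colossalai_module", False) for m in net.modules())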
@@ -264,23 +266,54 @@ def setup_precision_plugin(self) -> None:
                 )
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         pl_module = self.model
-        process_group = ProcessGroup()
+
         if not hasattr(pl_module, "_colossalai_zero"):
-            if self.use_chunk:
-                chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
-                    self.model, **self.chunk_size_search_kwargs
+            if not _COLOSSALAI_GREATER_0_1_10:
+                if self.use_chunk:
+                    chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
+                        self.model, **self.chunk_size_search_kwargs
+                    )
+                else:
+                    chunk_size = None
+                process_group = ProcessGroup()
+                chunk_manager = ChunkManager(
+                    chunk_size,
+                    process_group,
+                    self.enable_distributed_storage,
+                    GeminiManager.get_default_device(self.placement_policy),
                 )
+                gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
             else:
-                chunk_size = None
-            chunk_manager = ChunkManager(
-                chunk_size,
-                process_group,
-                self.enable_distributed_storage,
-                GeminiManager.get_default_device(self.placement_policy),
-            )
-            gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
-            model = _LightningModuleWrapperBase(self.model)
-            self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
+                with _patch_cuda_is_available():
+                    from colossalai.nn.parallel import GeminiDDP
+                    from colossalai.utils import get_current_device
+                if not self.use_chunk:
+                    raise ValueError("`ColossalAIStrategy` must use chunk in versions higher than 0.1.10")
+                chunk_search_range: int = self.chunk_size_search_kwargs.get(
+                    "search_range", 32 * 1024**2
+                )  # type: ignore[assignment]
+                search_range_mb: float = chunk_search_range / 1024**2
+                search_n_grids: int = self.chunk_size_search_kwargs.get("n_grids", 4096)  # type: ignore[assignment]
+                search_interval: int = math.ceil(chunk_search_range / search_n_grids)
+                min_chunk_size_mb: float = self.chunk_size_search_kwargs.get(
+                    "min_chunk_size", 32 * 1024**2
+                )  # type: ignore[assignment]
+                min_chunk_size_mb /= 1024**2
+
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = GeminiDDP(
+                    module=model,
+                    device=get_current_device(),
+                    placement_policy=self.placement_policy,
+                    pin_memory=True,
+                    force_outputs_fp32=self.force_outputs_fp32,
+                    search_range_mb=search_range_mb,
+                    hidden_dim=search_interval,
+                    min_chunk_size_mb=min_chunk_size_mb,
+                )
+
             assert self.model is not None
             pl_module._colossalai_zero = [self.model]  # type: ignore[assignment]
         else:
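
# Sketch, not part of the diff: the unit conversions performed in the new
# GeminiDDP branch above. The strategy's settings are expressed in bytes, while
# GeminiDDP takes `search_range_mb` / `min_chunk_size_mb` in MiB and a
# `hidden_dim` search interval derived from the byte range and the grid count.
# The values below are the defaults that appear in this diff.
import math

chunk_search_range = 64 * 1024**2  # bytes (strategy default `chunk_search_range`)
search_n_grids = 4096              # new default `chunk_search_n_grids`

search_range_mb = chunk_search_range / 1024**2                    # 64.0 MiB
search_interval = math.ceil(chunk_search_range / search_n_grids)  # 16384 bytes per grid step

min_chunk_size = 32 * 1024**2      # bytes (fallback used when `min_chunk_size` is unset)
min_chunk_size_mb = min_chunk_size / 1024**2                      # 32.0 MiB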
@@ -329,10 +362,20 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
         assert self.lightning_module is not None
         self.lightning_module._device = self.root_device
+        self.ignore_no_grad_parameters(self.root_device)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
         self.model_to_device()

+    def ignore_no_grad_parameters(self, running_device: torch.device) -> None:
+        # Parameters that never receive gradients should be ignored by DDP
+        # and moved to the running CUDA device.
+        assert self.model is not None
+        for param in self.model.parameters():
+            if not param.requires_grad:
+                setattr(param, "_ddp_to_ignore", True)
+                param.data = param.data.to(running_device)
+
     def model_to_device(self) -> None:
         assert self.lightning_module is not None
         pl_module = self.lightning_module
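
# Sketch, not part of the diff: what `ignore_no_grad_parameters` does for a model
# that contains frozen weights. Parameters that never receive gradients are tagged
# with `_ddp_to_ignore` so the distributed wrapper skips them, and their data is
# moved to the running device up front. `TinyModel` and the device choice below
# are illustrative only.
import torch

class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.trainable = torch.nn.Linear(4, 4)
        self.frozen = torch.nn.Parameter(torch.zeros(4), requires_grad=False)

model = TinyModel()
running_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for param in model.parameters():
    if not param.requires_grad:
        setattr(param, "_ddp_to_ignore", True)      # flag the wrapper to skip this parameter
        param.data = param.data.to(running_device)  # move it to the training device up front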