@@ -46,7 +46,7 @@ def __init__(
4646 reduce_bucket_size : int = 1024 * 1024 , # communication
4747 communication_dtype : Optional [torch .dtype ] = None ,
4848 overlap_communication : bool = False ,
49- partition_grad : bool = False , # stage 2
49+ partition_grad : bool = False , # stage 2 flag
5050 cpu_offload : bool = False , # cpu offload
5151 forced_dtype : Optional [torch .dtype ] = None ):
5252
@@ -248,9 +248,13 @@ def _partition_param_list(self, param_list):
248248 self ._logger .info (f'Number of elements on ranks: { numel_per_rank } ' , ranks = [0 ])
249249 return params_per_rank
250250
251- ###########################################################
252- # Backward Reduction Hook
253- ###########################################################
251+ ###########################
252+ # Backward Reduction Hook #
253+ ###########################
254+
255+ def _grad_handler (self , param , grad , reduce_rank ):
256+ self ._add_to_reduction_bucket (param , reduce_rank )
257+ return grad
254258
255259 def _attach_reduction_hook (self ):
256260 # we iterate over the fp16 params
@@ -268,53 +272,61 @@ def _attach_reduction_hook(self):
268272 else :
269273 reduce_rank = None
270274
271- def _define_and_attach (param , reduce_rank ):
272- # get the AccumulateGrad object of the param itself
273- accum_grad_obj = get_grad_accumulate_object (param )
274- self ._grad_store .add_accumulate_grad_object (accum_grad_obj )
275+ param .register_hook (partial (self ._grad_handler , param , reduce_rank = reduce_rank ))
275276
276- reduction_func = partial (self ._reduce_and_remove_grads_by_bucket ,
277- param = param ,
278- reduce_rank = reduce_rank )
def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
    """Flatten *bucket*, reduce it across the DP group, and copy the result back.

    With communication/computation overlap enabled the reduce is issued on the
    dedicated communication stream (after synchronizing and clearing grads of
    previously reduced params); otherwise it runs on the current stream.  The
    reduced data is written back into the bucket only on the destination rank,
    or on every rank when ``reduce_rank`` is ``None``.
    """
    if self._overlap_communication:
        # Wait for outstanding work before touching grads, then release the
        # grads that were reduced in the previous round.
        torch.cuda.synchronize()
        self._param_store.clear_grads_of_previous_reduced_params()
        stream = self._comm_stream
    else:
        stream = torch.cuda.current_stream()

    with torch.cuda.stream(stream):
        flat_tensor = bucket.flatten()
        if reduce_rank is None:
            dst_global_rank = None
        else:
            # Map the DP-local destination rank to its global rank.
            dst_global_rank = self._dp_global_ranks[reduce_rank]
        reduced = reduce_tensor_dp_group(tensor=flat_tensor,
                                         dtype=self._communication_dtype,
                                         dst_local_rank=reduce_rank,
                                         dst_global_rank=dst_global_rank,
                                         group=self._dp_torch_group)

        # Only ranks that own the result copy it back into the bucket.
        if reduce_rank is None or reduce_rank == self._local_rank:
            bucket.unflatten_and_copy(reduced)
287299
288- _define_and_attach (param , reduce_rank )
def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
    """Reduce *tensor_list* (tensors of a single dtype) in *bucket_size* chunks.

    Tensors are packed into a ``TensorBucket``; each time the bucket becomes
    full (or a single oversized tensor is added) its contents are reduced and
    the bucket is emptied.  A final partial bucket is flushed at the end.
    """
    bucket = TensorBucket(size=bucket_size)

    for tensor in tensor_list:
        bucket.add_to_bucket(tensor, allow_oversize=True)

        if bucket.is_full_or_oversized():
            self._reduce_tensor_bucket(bucket=bucket, reduce_rank=reduce_rank)
            bucket.empty()

    # Flush whatever is left over after the loop.
    if not bucket.is_empty():
        self._reduce_tensor_bucket(bucket=bucket, reduce_rank=reduce_rank)
305312
306- # the param must have grad for reduction
307- assert param . grad is not None , f'Parameter of size ( { param . size () } ) has None grad, cannot be reduced'
def _reduce_grads(self, reduce_rank, grads, bucket_size):
    """Reduce *grads*, grouped by dtype, in buckets of at most *bucket_size*.

    ``split_half_float_double`` partitions the gradients so that each bucket
    holds tensors of a single dtype before the bucketed reduction runs.
    """
    for same_dtype_grads in split_half_float_double(grads):
        self._reduce_tensor_list_with_one_dtype(tensor_list=same_dtype_grads,
                                                bucket_size=bucket_size,
                                                reduce_rank=reduce_rank)
320+
321+ #######################
322+ # Reduction Functions #
323+ #######################
312324
313- def _reduce_grads_in_bucket (self , reduce_rank = None ):
325+ def _run_reduction (self , reduce_rank = None ):
314326 # reduce grads
315- self ._reduce_grads_by_rank (reduce_rank = reduce_rank ,
316- grads = self ._bucket_store .get_grad (reduce_rank = reduce_rank ),
317- bucket_size = self ._bucket_store .num_elements_in_bucket (reduce_rank ))
327+ self ._reduce_grads (reduce_rank = reduce_rank ,
328+ grads = self ._bucket_store .get_grad (reduce_rank = reduce_rank ),
329+ bucket_size = self ._bucket_store .num_elements_in_bucket (reduce_rank ))
318330
319331 # use communication stream if overlapping
320332 # communication with computation
@@ -351,50 +363,24 @@ def _reduce_grads_in_bucket(self, reduce_rank=None):
351363
352364 self ._bucket_store .reset_by_rank (reduce_rank )
353365
354- def _reduce_grads_by_rank (self , reduce_rank , grads , bucket_size ):
355- grad_buckets_by_dtype = split_half_float_double (grads )
356-
357- for tensor_list in grad_buckets_by_dtype :
358- self ._reduce_no_retain (tensor_list = tensor_list , bucket_size = bucket_size , reduce_rank = reduce_rank )
359-
360- ##############################
361- # Reduction Utility Function #
362- ##############################
363- def _reduce_no_retain (self , tensor_list , bucket_size , reduce_rank ):
364- param_bucket = TensorBucket (size = bucket_size )
365-
366- for tensor in tensor_list :
367- param_bucket .add_to_bucket (tensor , allow_oversize = True )
368-
369- if param_bucket .is_full_or_oversized ():
370- self ._reduce_and_copy (bucket = param_bucket , reduce_rank = reduce_rank )
371- param_bucket .empty ()
372-
373- if not param_bucket .is_empty ():
374- self ._reduce_and_copy (bucket = param_bucket , reduce_rank = reduce_rank )
366+ def _add_to_reduction_bucket (self , param , reduce_rank = None ):
367+ param_size = param .numel ()
375368
376- def _reduce_and_copy (self , bucket : TensorBucket , reduce_rank ):
377- if self ._overlap_communication :
378- torch .cuda .synchronize ()
379- self ._param_store .clear_grads_of_previous_reduced_params ()
380- stream = self ._comm_stream
381- else :
382- stream = torch .cuda .current_stream ()
369+ # check if the bucket is full
370+ # if full, will reduce the grads already in the bucket
371+ # after reduction, the bucket will be empty
372+ if self ._bucket_store .num_elements_in_bucket (reduce_rank ) + param_size > self ._reduce_bucket_size :
373+ self ._run_reduction (reduce_rank )
383374
384- with torch .cuda .stream (stream ):
385- flat = bucket .flatten ()
386- reduce_global_rank = None
387- if reduce_rank is not None :
388- reduce_global_rank = self ._dp_global_ranks [reduce_rank ]
389- reduced_flat = reduce_tensor_dp_group (tensor = flat ,
390- dtype = self ._communication_dtype ,
391- dst_local_rank = reduce_rank ,
392- dst_global_rank = reduce_global_rank ,
393- group = self ._dp_torch_group )
375+ # the param must not be reduced to ensure correctness
376+ is_param_reduced = self ._param_store .is_param_reduced (param )
377+ if is_param_reduced :
378+ msg = f'Parameter of size ({ param .size ()} ) has already been reduced, ' \
379+ + 'duplicate reduction will lead to arithmetic incorrectness'
380+ raise RuntimeError (msg )
394381
395- # update the reduced tensor
396- if reduce_rank is None or reduce_rank == self ._local_rank :
397- bucket .unflatten_and_copy (reduced_flat )
382+ self ._bucket_store .add_num_elements_in_bucket (param_size , reduce_rank )
383+ self ._bucket_store .add_param (param , reduce_rank )
398384
399385 ################################
400386 # torch.optim.Optimizer methods
@@ -498,8 +484,9 @@ def step(self, closure=None):
498484 # broadcast the updated model weights
499485 handles = []
500486 for group_id in range (self .num_param_groups ):
501- for rank in range (self ._world_size ):
502- fp16_param = self ._param_store .get_flat_fp16_param_by_rank_group (rank = rank , group_id = group_id )
487+ for index in range (self ._world_size ):
488+ rank = self ._dp_global_ranks [index ]
489+ fp16_param = self ._param_store .get_flat_fp16_param_by_rank_group (rank = index , group_id = group_id )
503490 handle = dist .broadcast (fp16_param , src = rank , group = self ._dp_torch_group , async_op = True )
504491 handles .append (handle )
505492
@@ -585,16 +572,16 @@ def _reduce_grad_stage1(self):
585572 param_group = self ._fp16_param_groups [group_id ]
586573 for param in param_group :
587574 if param .grad is not None :
588- self ._reduce_and_remove_grads_by_bucket (param )
575+ self ._add_to_reduction_bucket (param )
589576
590577 # we need to reduce the gradients
591578 # left in the communication bucket
592- self ._reduce_grads_in_bucket ()
579+ self ._run_reduction ()
593580
594581 def _reduce_grad_stage2 (self ):
595582 # when partition_grads is True, reduction hooks
596583 # are attached in the __init__ function, so we
597584 # only need to reduce the gradients
598585 # left in the communication bucket
599586 for reduce_rank in range (self ._world_size ):
600- self ._reduce_grads_in_bucket (reduce_rank )
587+ self ._run_reduction (reduce_rank )