@@ -167,47 +167,57 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None, None, None


-def _AllgatherLinear(input_, weight, process_group):
+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+    # currently only supports a single output tensor
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-    weight_shape = weight.shape
-
-    output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+    # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]

     # initialization of ring communication
-    input_shape[1]
     recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
     send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
-    recv_tensor = input_.clone()
-    send_tensor = input_.clone()
-
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([send_op, recv_op])
+    recv_tensors = {}
+    send_tensors = {}
+    for k, v in input_to_gather.items():
+        recv_tensors[k] = v.clone()
+        send_tensors[k] = v.clone()
+
+    def communicate_step():
+        comm_ops = []
+        for k in recv_tensors:
+            comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+            comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+        return dist.batch_isend_irecv(comm_ops)
+
+    def switch_step():
+        for k in recv_tensors:
+            tmp_tensor = send_tensors[k]
+            send_tensors[k] = recv_tensors[k]
+            recv_tensors[k] = tmp_tensor
+
+    output_tensors = []
+
+    handles = communicate_step()
     # first round: special case, retrieve from local tensor
-    output_tensors[0] = F.linear(input_, weight)
+    output_tensors.append(func(**input_to_gather, **input_local))
     for i in range(group_size - 2):
         for handle in handles:
             handle.wait()

-        tmp_tensor = send_tensor
-        send_tensor = recv_tensor
-        recv_tensor = tmp_tensor
+        switch_step()

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

         # actual computation
-        output_tensors[i + 1] = F.linear(send_tensor, weight)
+        output_tensors.append(func(**send_tensors, **input_local))

     # final round: special case, no need to send/recv again
     for handle in handles:
         handle.wait()
-    output_tensors[group_size - 1] = F.linear(recv_tensor, weight)
-    return torch.cat(output_tensors[group_size - cur_rank:] + output_tensors[:group_size - cur_rank], dim=1)
+    output_tensors.append(func(**recv_tensors, **input_local))
+
+    return torch.cat(output_tensors[group_size - cur_rank:] + output_tensors[:group_size - cur_rank], dim=gather_dim)


 class _GatherForwardReduceScatterBackward(torch.autograd.Function):
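The new `_ring_as_gather` replaces one big all-gather with `group_size - 1` ring steps, overlapping each P2P send/recv with computation on the chunk that is already local. The only non-obvious part is the final `torch.cat` rotation. Below is a minimal single-process sketch (no `torch.distributed`; the group size, shapes, and the use of `F.linear` are illustrative assumptions) that mimics the order in which chunks arrive on each rank and checks that the rotation restores plain all-gather order:

```python
import torch
import torch.nn.functional as F

group_size, batch, seq_per_rank, hidden, out_features = 4, 2, 3, 8, 5
weight = torch.randn(out_features, hidden)
# shards[r] is the sequence chunk owned by rank r before the gather.
shards = [torch.randn(batch, seq_per_rank, hidden) for _ in range(group_size)]

# Reference semantics: all-gather along dim 1, then a single linear.
expected = F.linear(torch.cat(shards, dim=1), weight)

for cur_rank in range(group_size):
    # Ring order: rank r computes on chunks r, r+1, ..., r-1 (mod group_size),
    # since each step it receives from recv_rank = (cur_rank + 1) % group_size.
    output_tensors = [F.linear(shards[(cur_rank + i) % group_size], weight) for i in range(group_size)]
    # The rotation used at the end of _ring_as_gather restores chunk order 0, 1, ..., group_size - 1.
    ring_output = torch.cat(
        output_tensors[group_size - cur_rank:] + output_tensors[:group_size - cur_rank],
        dim=1,
    )
    assert torch.allclose(ring_output, expected)
```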
@@ -247,6 +257,41 @@ def backward(ctx, grad_output):
         return output, None, None


+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        dim (`int`): The dimension along which to gather in forward and reduce-scatter the gradient in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # do reduce-scatter
+        new_shape = list(grad_output.shape)
+        assert (
+            new_shape[dim] % dist.get_world_size(process_group) == 0
+        ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)})."
+        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+        grad_list = [item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)]
+        output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+        dist.reduce_scatter(output, grad_list, group=process_group)
+
+        return output, None, None
+
+
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """Gather input from sequence parallel in forward and reduce-scatter gradient in backward

@@ -258,19 +303,35 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        if bias is not None:
-            input_parallel = _gather(input_, dim, process_group)
-            output = F.linear(input_parallel, weight, bias)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['weight'] = weight
+
+            output = _ring_as_gather(
+                F.linear,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+            )
+
+            if bias is not None:
+                output += bias
         else:
-            output = _AllgatherLinear(input_, weight, process_group)
+            input_parallel = _gather(input_, dim, process_group)
+            if bias is not None:
+                output = F.linear(input_parallel, weight, bias)
+            else:
+                output = F.linear(input_parallel, weight)

         return output

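In the `ring=True` branch the bias is added once, after `_ring_as_gather` has concatenated the per-chunk `F.linear(x_chunk, weight)` results, instead of being passed to `F.linear` for every chunk. A quick single-process check of that equivalence (all shapes are made-up example values):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 12, 8)      # (batch, gathered sequence, hidden)
weight = torch.randn(5, 8)
bias = torch.randn(5)

chunks = torch.chunk(x, 4, dim=1)   # the per-rank shards of the gathered input
via_ring = torch.cat([F.linear(c, weight) for c in chunks], dim=1) + bias
reference = F.linear(x, weight, bias)
assert torch.allclose(via_ring, reference, atol=1e-6)
```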
@@ -373,34 +434,43 @@ def backward(ctx, grad_output):
         # wait until reduce-scatter finished
         reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


-def _ReduceScatterLinear(input_, weight, process_group):
+def _ring_as_reducescatter(func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1):
+    # currently only supports a single output tensor
     group_size = dist.get_world_size(process_group)
     cur_rank = dist.get_rank(process_group)

-    input_shape = input_.shape
-
     # initialization of ring communication
-    # communicate(e.g.): 0->1->2->3
-    # compute(e.g.): 3->2->1->0
-    input_tensors = list(torch.split(input_, int(input_shape[1] / group_size), dim=1))
-    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
-    input_tensors.reverse()
     recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
     send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    input_tensors = []
+    for _ in range(group_size):
+        input_tensors.append({})
+    for k, v in input_to_reducescatter.items():
+        input_shape = v.shape
+        assert input_shape[reducescatter_dim] % group_size == 0
+        _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+        for i in range(group_size):
+            input_tensors[i][k] = _input_tensors[i]
+    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+    input_tensors.reverse()

-    # first round: special case, no reduce operation
-    output_tensor = F.linear(input_tensors[0], weight)
+    output_tensor = func(**input_tensors[0], **input_local)
     recv_tensor = output_tensor.clone()
     send_tensor = output_tensor.clone()
-    recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-    send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-    handles = dist.batch_isend_irecv([recv_op, send_op])
+
+    def communicate_step():
+        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+        return dist.batch_isend_irecv([recv_op, send_op])
+
+    handles = communicate_step()
+    # first round: special case, retrieve from local tensor
     for i in range(group_size - 2):
         # actual computation
-        output_tensor = F.linear(input_tensors[i + 1], weight)
+        output_tensor = func(**input_tensors[i + 1], **input_local)

         for handle in handles:
             handle.wait()
@@ -410,12 +480,10 @@ def _ReduceScatterLinear(input_, weight, process_group):
         send_tensor = output_tensor
         output_tensor = tmp_tensor

-        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
-        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
-        handles = dist.batch_isend_irecv([recv_op, send_op])
+        handles = communicate_step()

     # final round: special case, no need to send/recv again
-    output_tensor = F.linear(input_tensors[group_size - 1], weight)
+    output_tensor = func(**input_tensors[-1], **input_local)
     for handle in handles:
         handle.wait()
     output_tensor += recv_tensor
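`_ring_as_reducescatter` walks the ring in the opposite direction: each step a rank computes `func` on one split of its local input and forwards the running partial sum to `send_rank = (cur_rank + 1) % group_size`, so after the final round rank r holds the sum, over all ranks, of split r. The single-process sketch below (made-up shapes, plain summation instead of P2P ops) only checks that this end result matches the usual reduce-then-scatter semantics; it does not model the communication schedule itself:

```python
import torch
import torch.nn.functional as F

group_size = 4
weight = torch.randn(5, 8)
# Each rank holds a full-length activation whose linear output must be
# reduce-scattered along dim 1.
inputs = [torch.randn(2, 4 * group_size, 8) for _ in range(group_size)]

# Reference semantics: sum the per-rank outputs, then scatter chunks of dim 1.
full_sum = sum(F.linear(x, weight) for x in inputs)
expected = torch.chunk(full_sum, group_size, dim=1)

# What each rank ends up with after the ring: the sum over ranks of its own chunk.
for r in range(group_size):
    ring_result = sum(F.linear(torch.chunk(x, group_size, dim=1)[r], weight) for x in inputs)
    assert torch.allclose(ring_result, expected[r], atol=1e-5)
```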
@@ -433,27 +501,44 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, dim):
+    def forward(ctx, input_, weight, bias, process_group, dim, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.dim = dim
-        if bias is not None:
-            partial_output = F.linear(input_, weight, bias)
+
+        if ring is True:
+            input_to_reducescatter = {}
+            input_local = {}
+            input_to_reducescatter['input'] = input_
+            input_local['weight'] = weight
+
+            if bias is not None:
+                input_to_reducescatter['bias'] = bias
+
+            output = _ring_as_reducescatter(
+                F.linear,
+                input_to_reducescatter=input_to_reducescatter,
+                input_local=input_local,
+                process_group=process_group,
+            )
         else:
-            return _ReduceScatterLinear(input_, weight, process_group)
+            if bias is not None:
+                partial_output = F.linear(input_, weight, bias)
+            else:
+                partial_output = F.linear(input_, weight)

-        output_shape = list(partial_output.shape)
-        assert (
-            output_shape[dim] % dist.get_world_size(process_group) == 0
-        ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)})."
-        output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+            output_shape = list(partial_output.shape)
+            assert (
+                output_shape[dim] % dist.get_world_size(process_group) == 0
+            ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)})."
+            output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)

-        output_list = [
-            item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
-        ]
-        output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
-        dist.reduce_scatter(output, output_list, group=process_group)
+            output_list = [
+                item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+            dist.reduce_scatter(output, output_list, group=process_group)

         return output

@@ -481,7 +566,7 @@ def backward(ctx, grad_output):
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        return grad_input, grad_weight, grad_bias, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None


 class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
@@ -530,17 +615,32 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
         ctx.dim = dim
         ctx.overlap = overlap

-        input_parallel = _gather(input_, dim, process_group)
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather['input'] = input_
+            input_local['other'] = weight
+
+            output = _ring_as_gather(
+                torch.matmul,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+                gather_dim=dim,
+            )
+
+        else:
+            input_parallel = _gather(input_, dim, process_group)

-        output = torch.matmul(input_parallel, weight)
+            output = torch.matmul(input_parallel, weight)

         if bias is not None:
             output = output + bias
@@ -620,7 +720,7 @@ def backward(ctx, grad_output):
         # wait until reduce-scatter finished
         reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -873,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )


@@ -888,15 +988,15 @@ def reducescatter_forward_gather_backward(input_, process_group, dim):
     return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)


-def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1):
-    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim)
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )


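For context, a hedged usage sketch of the new `ring` switch as seen from a caller such as a sequence-parallel column linear layer. Everything except `linear_gather_forward_reducescatter_backward` and its argument order is an illustrative assumption (the wrapper name `column_linear_forward`, the group variable `sp_group`, the flag values), and it only runs inside an initialized `torch.distributed` process group:

```python
# Hypothetical caller; assumes torch.distributed is initialized and `sp_group`
# is the sequence-parallel process group.
def column_linear_forward(x, weight, bias, sp_group, seq_dim=1, use_ring=False):
    # seq_dim is the dimension that was split for sequence parallelism.
    return linear_gather_forward_reducescatter_backward(
        x,
        weight,
        bias,
        sp_group,
        True,           # async_grad_reduce_scatter
        seq_dim,        # dim
        False,          # overlap: the all-gather/grad overlap is a separate option from ring
        ring=use_ring,  # True -> _ring_as_gather path, False -> _gather + F.linear
    )
```

With `ring=False` (the default) the wrappers keep their previous gather-then-matmul behavior, so existing call sites do not need to change.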