@@ -124,6 +124,20 @@ def allgather(
         """
         raise NotImplementedError("not implemented")

+    # pyre-fixme[14]: inconsistent override
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        """
+        Performs an allgather operation on coalesced tensors.
+
+        See torch.distributed.allgather_coalesced for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     # pyre-fixme[14]: inconsistent override
     def allreduce(
         self,
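The new abstract method follows the usual coalesced tensor-allgather contract: each output tensor receives the matching input from every rank, typically concatenated along the first dimension. A minimal local sketch of that shape contract, assuming a world size of 2 and emulating the gather with plain torch (no process group involved):

    import torch

    world_size = 2
    inputs = [torch.randn(4), torch.randn(3, 2)]  # this rank's inputs
    # Assume every rank contributes an identically shaped tensor; each output
    # is the concatenation of the matching input from all ranks along dim 0.
    contributions = [[t.clone() for t in inputs] for _ in range(world_size)]
    outputs = [
        torch.cat([contributions[r][i] for r in range(world_size)], dim=0)
        for i in range(len(inputs))
    ]
    assert outputs[0].shape == (world_size * 4,)
    assert outputs[1].shape == (world_size * 3, 2)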
@@ -212,6 +226,20 @@ def reduce_scatter(
         """
         raise NotImplementedError("not implemented")

+    # pyre-fixme[14]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        """
+        Performs a reduce-scatter operation on coalesced tensors.
+
+        See torch.distributed.reduce_scatter_tensor for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     # pyre-fixme[14]: inconsistent override
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         """
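Reduce-scatter is the dual of the allgather above: every rank contributes a full-size input, the inputs are reduced element-wise across ranks, and each rank keeps only its own chunk of the result. A hedged local emulation of that contract, again in plain torch with an assumed world size of 2 and sum reduction:

    import torch

    world_size = 2
    # One full-size input per rank; rank r holds arange(6) + r.
    inputs_per_rank = [torch.arange(6, dtype=torch.float32) + r for r in range(world_size)]
    chunks = [inp.chunk(world_size) for inp in inputs_per_rank]
    # Rank r's output is the element-wise sum of chunk r from every rank.
    outputs = [
        torch.stack([chunks[p][r] for p in range(world_size)]).sum(dim=0)
        for r in range(world_size)
    ]
    assert outputs[0].tolist() == [1.0, 3.0, 5.0]  # (0+1, 1+2, 2+3)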
@@ -336,10 +364,20 @@ def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
         return self.parent.allgather(output_tensors, input_tensor, opts)

+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        return self.parent.allgather_into_tensor_coalesced(
+            output_tensors, input_tensors, opts
+        )
+
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self.parent.allreduce(tensors, opts)

@@ -377,6 +415,16 @@ def reduce_scatter(
     ) -> Work:
         return self.parent.reduce_scatter(output_tensors, input_tensors, opts)

+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        return self.parent.reduce_scatter_tensor_coalesced(
+            output_tensors, input_tensors, opts
+        )
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         return self.parent.send(tensors, dst_rank, tag)

@@ -402,8 +450,15 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
         self._timeout = timeout

     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.GLOO)
         # pyre-fixme[16]: no attribute ProcessGroupGloo
-        return BaseProcessGroupGloo(store, rank, world_size, self._timeout)
+        backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
+        )
+        return pg

     def getBackendName(self) -> str:
         return "torchft-gloo"
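The reworked _create_pg no longer returns the raw Gloo backend; it builds a generic ProcessGroup, marks Gloo as its default backend, and registers the Gloo backend for CPU tensors so device-based dispatch works through the wrapper. A standalone sketch of the same pattern using only the bindings visible in this hunk (the store address, port, and single-rank world size are illustrative assumptions; PyTorch must be built with Gloo support):

    import torch
    import torch.distributed as dist
    from torch.distributed import ProcessGroup, TCPStore

    # Hedged sketch mirroring the hunk above, not torchft's public API.
    store = TCPStore("localhost", 29500, 1, is_master=True)
    pg = ProcessGroup(store, 0, 1)
    pg._set_default_backend(ProcessGroup.BackendType.GLOO)
    backend = dist.ProcessGroupGloo(store, 0, 1)
    backend._set_sequence_number_for_group()
    # CPU collectives issued through `pg` now dispatch to the Gloo backend.
    pg._register_backend(torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend)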
@@ -427,6 +482,28 @@ def reduce_scatter(
         """
         raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")

+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        Placeholder for the reduce_scatter_tensor_coalesced operation in the
+        ProcessGroupGloo class.
+
+        This operation is not supported by the Gloo backend, so calling this
+        method always raises a RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised, since reduce_scatter_tensor_coalesced
+                is not supported by ProcessGroupGloo.
+        """
+        raise RuntimeError(
+            "ProcessGroupGloo does not support reduce_scatter_tensor_coalesced."
+        )
+

 class ProcessGroupNCCL(ProcessGroupWrapper):
     """
@@ -440,8 +517,15 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     """

     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
-        return BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class = BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
+        )
+        return pg

     def getBackendName(self) -> str:
         return "torchft-nccl"
@@ -499,6 +583,19 @@ def allgather(
         self._work.append(res)
         return res

+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         res = _DummyWork(tensors)
         self._work.append(res)
@@ -548,6 +645,19 @@ def reduce_scatter(
         self._work.append(res)
         return res

+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         return _DummyWork(None)

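In the dummy process group both coalesced collectives reduce to a straight copy, which is the correct degenerate behaviour for a single-participant group: with nothing to gather from or reduce with, each output is just its matching input. A hedged sketch of the behaviour these hunks implement (the function name is illustrative, not part of the module):

    import torch

    def dummy_coalesced_collective(output_tensors, input_tensors):
        # With a single rank, gathering or reducing over "all ranks" leaves the
        # input unchanged, so every output is an in-place copy of its input.
        for o, i in zip(output_tensors, input_tensors):
            o.copy_(i)
        return output_tensors

    outs = [torch.empty(4), torch.empty(2, 2)]
    ins = [torch.randn(4), torch.randn(2, 2)]
    dummy_coalesced_collective(outs, ins)
    assert all(torch.equal(o, i) for o, i in zip(outs, ins))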
@@ -1134,6 +1244,20 @@ def allgather(
         _maybe_share_tensors(input_tensor)
         return self._run_func("allgather", output_tensors, input_tensor, opts)

+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        _assert_list(output_tensors)
+        _assert_list(input_tensors)
+        _maybe_share_tensors(output_tensors)
+        _maybe_share_tensors(input_tensors)
+        return self._run_func(
+            "allgather_into_tensor_coalesced", output_tensors, input_tensors, opts
+        )
+
     def allreduce(
         self,
         tensors: List[torch.Tensor],
@@ -1200,6 +1324,20 @@ def reduce_scatter(
         _maybe_share_tensors(input_tensors)
         return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)

+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        _assert_list(output_tensors)
+        _assert_list(input_tensors)
+        _maybe_share_tensors(output_tensors)
+        _maybe_share_tensors(input_tensors)
+        return self._run_func(
+            "reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts
+        )
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         _assert_list(tensors)
         _maybe_share_tensors(tensors)
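As with the other baby-process-group methods, the coalesced variants validate the tensor lists and pass them through _maybe_share_tensors before queuing the named operation for the worker subprocess; the bodies of _assert_list and _maybe_share_tensors are not shown in this diff. A hedged sketch of the pattern that helper is assumed to follow for CPU tensors (the function name here is illustrative):

    import torch
    from typing import List

    def share_cpu_tensors(tensors: List[torch.Tensor]) -> None:
        # Assumption about the helper's intent: CPU tensors are moved into
        # shared memory so the child process can access them without copies.
        for t in tensors:
            if t.device.type == "cpu":
                t.share_memory_()

    bufs = [torch.zeros(3), torch.zeros(5)]
    share_cpu_tensors(bufs)
    assert all(t.is_shared() for t in bufs)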
@@ -1278,8 +1416,14 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):

     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.GLOO)
         # pyre-fixme[16]: no attribute ProcessGroupGloo
-        return BaseProcessGroupGloo(store, rank, world_size)
+        backend_class = BaseProcessGroupGloo(store, rank, world_size)
+        pg._register_backend(
+            torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
+        )
+        return pg

     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
@@ -1303,6 +1447,28 @@ def reduce_scatter(
         """
         raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")

+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        Placeholder for the reduce_scatter_tensor_coalesced operation in the
+        ProcessGroupBabyGloo class.
+
+        This operation is not supported by the Gloo backend, so calling this
+        method always raises a RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised, since reduce_scatter_tensor_coalesced
+                is not supported by ProcessGroupBabyGloo.
+        """
+        raise RuntimeError(
+            "ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced."
+        )
+

 class ProcessGroupBabyNCCL(ProcessGroupBaby):
     """
@@ -1322,8 +1488,15 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):

     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
-        return BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class = BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
+        )
+        return pg

     def getBackendName(self) -> str:
         return "torchft-baby-nccl"