@@ -463,6 +463,10 @@ at::Tensor one_shot_all_reduce_out_impl(
         local_input->numel() <= input.numel(),
         "one_shot_all_reduce: local input size must be smaller than symm buffer size.");
   }
+  if (input.numel() == 0) {
+    TORCH_CHECK(input.scalar_type() == out.scalar_type());
+    return out;
+  }
   auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
   TORCH_CHECK(
       symm_mem != nullptr,
@@ -555,9 +559,14 @@ at::Tensor one_shot_all_reduce_copy(
 }
 
 constexpr size_t two_shot_all_reduce_max_num_blocks = 24;
-constexpr size_t two_shot_all_reduce_max_num_threads = 512;
-
-template <typename T, int alignment, int k_world_size>
+constexpr size_t two_shot_all_reduce_max_num_threads = 1024;
+
+template <
+    typename T,
+    int alignment,
+    int k_world_size,
+    bool reduce_scatter = false,
+    bool split_last_dim = false>
 static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
     void two_shot_all_reduce_kernel(
         T** input_ptrs,
@@ -566,31 +575,48 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
         size_t numel,
         uint32_t** signal_pads,
         size_t rank,
-        size_t world_size) {
+        size_t world_size,
+        size_t last_dim_size = 0) {
   static_assert(alignment % sizeof(T) == 0);
   constexpr size_t numel_per_thread = alignment / sizeof(T);
-
+  int32_t N_last_dim =
+      last_dim_size / world_size; // used only for split_last_dim reduce_scatter
   sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
   __syncthreads();
 
   const size_t numel_per_rank =
-      at::round_up(numel, alignment * world_size) / world_size;
-  const size_t start = numel_per_rank * rank;
+      at::round_up(numel, numel_per_thread * world_size) / world_size;
+  const size_t start = split_last_dim ? last_dim_size / world_size * rank
+                                      : numel_per_rank * rank;
 
   auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
   auto stride = blockDim.x * gridDim.x * numel_per_thread;
   for (size_t i = offset; i < numel_per_rank; i += stride) {
-    if (start + i >= numel) {
-      continue;
+    if constexpr (!reduce_scatter) {
+      // we call reduce-scatter only with evenly divisible number of elements
+      if (start + i >= numel) {
+        continue;
+      }
+    }
+    size_t idx = i;
+    if constexpr (split_last_dim) {
+      idx = i / N_last_dim * last_dim_size + i % N_last_dim;
     }
     auto vec = load_and_reduce<T, alignment, k_world_size>(
-        input_ptrs, rank, world_size, input_offset + start + i);
-    // store to local buffer
-    st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
+        input_ptrs, rank, world_size, input_offset + start + idx);
+    // store to local buffer or to output
+    if constexpr (reduce_scatter) {
+      st_vec<alignment>(output_ptr + i, vec);
+    } else {
+      st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
+    }
   }
 
   __syncthreads();
   sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  if constexpr (reduce_scatter) {
+    return;
+  }
   __syncthreads();
   for (size_t i = offset; i < numel_per_rank; i += stride) {
     Vec<alignment> tmp[k_world_size];
@@ -611,8 +637,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
       if (remote_start + i >= numel) {
         continue;
       }
-      st_vec<alignment>(
-          output_ptr + remote_start + i, tmp[step]);
+      st_vec<alignment>(output_ptr + remote_start + i, tmp[step]);
     }
   }
   // need to make sure all blocks exit simultaneously so that the data
@@ -679,11 +704,28 @@ at::Tensor two_shot_all_reduce_impl(
       get_and_verify_alignment(input, "two_shot_all_reduce");
 
   if (output.has_value()) {
+    TORCH_CHECK(
+        output->is_contiguous(),
+        "two_shot_all_reduce: output must be contiguous.");
     const size_t output_alignment =
         get_and_verify_alignment(*output, "two_shot_all_reduce");
     TORCH_CHECK(
         alignment <= output_alignment,
         "two_shot_all_reduce: output alignment must be equal to or larger than input.");
+    TORCH_CHECK(
+        output->sizes() == input.sizes(),
+        "two_shot_all_reduce: input/output size mismatch, input.sizes(): ",
+        input.sizes(),
+        ", output.sizes(): ",
+        output->sizes());
+    if (input.numel() == 0) {
+      TORCH_CHECK(output->scalar_type() == input.scalar_type());
+      return *output;
+    }
+  } else {
+    if (input.numel() == 0) {
+      return input;
+    }
   }
 
   int num_blocks = 0, num_threads = 0;
@@ -764,6 +806,146 @@ at::Tensor two_shot_all_reduce_out(
     at::Tensor output) {
   return two_shot_all_reduce_impl(input, output, reduce_op, group_name);
 }
+
+at::Tensor reduce_scatter_out(
+    at::Tensor input,
+    std::string group_name,
+    bool split_last_dim,
+    at::Tensor output) {
+  TORCH_CHECK(
+      input.is_contiguous(), "reduce_scatter: input must be contiguous.");
+  TORCH_CHECK(
+      output.is_contiguous(), "reduce_scatter: output must be contiguous.");
+
+  auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
+  TORCH_CHECK(
+      symm_mem != nullptr,
+      "reduce_scatter: input must be allocated with empty_strided_p2p().");
+
+  const size_t alignment = get_and_verify_alignment(input, "reduce_scatter");
+
+  const size_t output_alignment =
+      get_and_verify_alignment(output, "reduce_scatter");
+
+  TORCH_CHECK(
+      input.numel() %
+              (symm_mem->get_world_size() *
+               (alignment / input.element_size())) ==
+          0,
+      "expected number of elements to be divisible by world_size * alignment, number of elements ",
+      input.numel(),
+      " world size ",
+      symm_mem->get_world_size(),
+      " alignment ",
+      alignment);
+
+  if (split_last_dim) {
+    TORCH_CHECK(input.dim() == output.dim());
+    bool are_equal_except_last = std::equal(
+        input.sizes().begin(), input.sizes().end() - 1, output.sizes().begin());
+    TORCH_CHECK(
+        are_equal_except_last,
+        "reduce_scatter expected input and output to have same sizes except in the last dimension");
+    TORCH_CHECK(
+        output.size(-1) == input.size(-1) / symm_mem->get_world_size(),
+        "reduce_scatter expected output last dim size to be input last dim size / world_size");
+
+    TORCH_CHECK(
+        input.size(-1) %
+                (symm_mem->get_world_size() *
+                 (alignment / input.element_size())) ==
+            0,
+        "expected last dimension to be divisible by world_size * alignment, last dimension ",
+        input.size(-1),
+        " world size ",
+        symm_mem->get_world_size(),
+        " alignment ",
+        alignment);
+  } else {
+    TORCH_CHECK(input.dim() == 1, "reduce_scatter expected 1D input");
+    TORCH_CHECK(output.dim() == 1, "reduce_scatter expected 1D output");
+    TORCH_CHECK(output.numel() == input.numel() / symm_mem->get_world_size());
+  }
+  if (input.numel() == 0) {
+    TORCH_CHECK(input.scalar_type() == output.scalar_type());
+    return output;
+  }
+
+  TORCH_CHECK(
+      output_alignment >= alignment,
+      "reduce_scatter: output alignment should be not smaller than input alignment");
+
+  int num_blocks = 0, num_threads = 0;
+  init_elementwise_launch_config(
+      input.numel(),
+      input.element_size(),
+      alignment,
+      symm_mem->get_world_size(),
+      two_shot_all_reduce_max_num_blocks,
+      two_shot_all_reduce_max_num_threads,
+      num_blocks,
+      num_threads);
+  if (split_last_dim) {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel<
+                  scalar_t,
+                  k_alignment,
+                  k_world_size,
+                  true,
+                  true>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      output.data_ptr<scalar_t>(),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size(),
+                      input.size(-1));
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+          });
+        });
+  } else {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel<
+                  scalar_t,
+                  k_alignment,
+                  k_world_size,
+                  true,
+                  false>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      output.data_ptr<scalar_t>(),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size(),
+                      input.size(-1));
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+          });
+        });
+  }
+  return output;
+}
 } // namespace
 #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
 
@@ -899,6 +1081,7 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
   m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
+  m.impl("reduce_scatter_out", ::reduce_scatter_out);
 
   m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
 #endif
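
For reference, the split_last_dim addressing that the reduce_scatter path above relies on can be exercised on the host. The sketch below is illustrative only: the world size, tensor shape, and plain float loops are assumptions, and it omits the symmetric-memory rendezvous, block signaling, and vectorized aligned loads/stores that two_shot_all_reduce_kernel actually performs. It only demonstrates the idx = i / N_last_dim * last_dim_size + i % N_last_dim mapping from a rank's contiguous output slice back into the shared input layout.

// Host-side illustration of the split_last_dim indexing (illustrative sketch,
// not the kernel: no symmetric memory, no signaling, no vectorization).
#include <cstdio>
#include <vector>

int main() {
  const int world_size = 4;     // hypothetical number of ranks
  const int rows = 3;           // product of all leading dimensions
  const int last_dim_size = 8;  // must be divisible by world_size
  const int N_last_dim = last_dim_size / world_size;  // per-rank slice width

  // Simulated per-rank input buffers: rank r holds the value r + 1 everywhere.
  std::vector<std::vector<float>> inputs(
      world_size, std::vector<float>(rows * last_dim_size));
  for (int r = 0; r < world_size; ++r) {
    for (auto& v : inputs[r]) {
      v = static_cast<float>(r + 1);
    }
  }

  // Reduce-scatter for one rank: output element i corresponds to row
  // i / N_last_dim and column start + i % N_last_dim of the input, i.e.
  //   idx = i / N_last_dim * last_dim_size + i % N_last_dim
  // as in two_shot_all_reduce_kernel with split_last_dim = true.
  const int rank = 1;
  const int start = N_last_dim * rank;  // column offset of this rank's slice
  std::vector<float> output(rows * N_last_dim);
  for (int i = 0; i < rows * N_last_dim; ++i) {
    const int idx = i / N_last_dim * last_dim_size + i % N_last_dim;
    float acc = 0.f;
    for (int r = 0; r < world_size; ++r) {
      acc += inputs[r][start + idx];
    }
    output[i] = acc;  // expected: 1 + 2 + 3 + 4 = 10 everywhere
  }
  std::printf("output[0] = %f\n", output[0]);
  return 0;
}

Each output element lands at row i / N_last_dim, column rank * N_last_dim + i % N_last_dim of the input, which is the slice described by the wrapper's output.size(-1) == input.size(-1) / world_size check.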