
Commit d5cc628

ngimel authored and timocafe committed
add reduce_scatter to symm mem ops (pytorch#150813)
+ a few small fixes (don't error out on 0-element tensors, a few more checks for contiguous outputs, more threads for better perf).

Pull Request resolved: pytorch#150813
Approved by: https://github.com/xw285cornell
1 parent 7879dda commit d5cc628
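
For orientation, the new op is driven from Python roughly as in the test added by this PR. The snippet below is a minimal sketch, not part of the commit: it assumes a CUDA process group is already initialized on every rank (e.g. 4 GPUs) and that the symmetric-memory module import matches the test file.

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    group_name = dist.group.WORLD.group_name
    world_size = dist.get_world_size()

    # The input must live in a symmetric-memory buffer and be rendezvous'd first.
    inp = symm_mem.empty(8192, dtype=torch.bfloat16, device="cuda").normal_()
    symm_mem.rendezvous(inp, group=group_name)

    # Each rank receives numel // world_size reduced elements of the 1-D input.
    out = torch.empty(inp.numel() // world_size, dtype=inp.dtype, device=inp.device)
    torch.ops.symm_mem.reduce_scatter_out(inp, group_name, False, out)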

File tree

4 files changed (+276, -17 lines)

test/distributed/test_symmetric_memory.py

Lines changed: 74 additions & 2 deletions
@@ -771,7 +771,7 @@ def test_subgroup(self) -> None:
         self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())


-@skipIfRocm
+# @skipIfRocm
 @instantiate_parametrized_tests
 @requires_cuda_p2p_access()
 class SymmMemCollectiveTest(MultiProcessTestCase):
@@ -912,7 +912,7 @@ def test_two_shot_all_reduce(self) -> None:
             shift = align_bytes // t.element_size()
             numel = size_bytes // t.element_size()
             res = t[shift : shift + numel]
-            res.normal_().fill_(1)
+            res.normal_()
             inp = res.clone()
             if not inplace:
                 out = torch.empty_like(inp)
@@ -940,6 +940,78 @@ def _verify_all_reduce_result(self, inp, res):
             gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01
         )

+    @skipIfRocm
+    @skip_if_lt_x_gpu(4)
+    def test_reduce_scatter(self) -> None:
+        self._init_process()
+        group_name = dist.group.WORLD.group_name
+
+        for dtype, size_bytes, align_bytes, split_last_dim in itertools.product(
+            [torch.float, torch.bfloat16],
+            [128, 8192, 36 * 1024 * 16],
+            [4, 8, 16],
+            [True, False],
+        ):
+            t = symm_mem.empty(36 * 1024 * 16, dtype=dtype, device=self.device).fill_(0)
+            symm_mem.rendezvous(t, group=group_name)
+
+            self.assertTrue(t.data_ptr() % 16 == 0)
+            self.assertTrue(align_bytes % t.element_size() == 0)
+            self.assertTrue(size_bytes % t.element_size() == 0)
+
+            shift = align_bytes // t.element_size()
+            numel = size_bytes // t.element_size()
+            res = t[shift : shift + numel].normal_()
+            if split_last_dim:
+                res = res.view(-1, 128 // t.element_size())
+            inp = res.clone()
+            out_size = list(inp.shape)
+            out_size[-1] = inp.shape[-1] // self.world_size
+            out = torch.empty(out_size, dtype=dtype, device=self.device)
+            torch.ops.symm_mem.reduce_scatter_out(res, group_name, split_last_dim, out)
+
+            # Head and tail should not be written
+            self.assertTrue(t[:shift].eq(0).all().item())
+            self.assertTrue(t[shift + numel :].eq(0).all().item())
+            self._verify_reduce_scatter_result(inp, out)
+
+        dist.destroy_process_group()
+
+    @skipIfRocm
+    @skip_if_lt_x_gpu(4)
+    def test_reduce_scatter_corner_cases(self) -> None:
+        dtype = torch.bfloat16
+        self._init_process()
+        group_name = dist.group.WORLD.group_name
+        t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0)
+        symm_mem.rendezvous(t, group=group_name)
+        res = t[:0]
+        out_size = res.shape[0] // self.world_size
+        out = torch.empty(out_size, dtype=dtype, device=self.device)
+        torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out)
+        res = t[:48]
+        out_size = res.shape[0] // self.world_size
+        out = torch.empty(out_size, dtype=dtype, device=self.device)
+        with self.assertRaisesRegex(RuntimeError, "divisible"):
+            torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out)
+        res = t[: 2 * 48].view(2, 48)
+        out = torch.empty(2, 48 // self.world_size, dtype=dtype, device=self.device)
+        with self.assertRaisesRegex(RuntimeError, "divisible"):
+            torch.ops.symm_mem.reduce_scatter_out(res, group_name, True, out)
+
+    def _verify_reduce_scatter_result(self, inp, res):
+        gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, *res.shape)
+        gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, *inp.shape)
+        sum_inps = gathered_inps.sum(0)
+        slice_width = sum_inps.shape[-1] // self.world_size
+        for i in range(self.world_size):
+            torch.testing.assert_close(
+                gathered_res[i],
+                sum_inps[..., i * slice_width : (i + 1) * slice_width],
+                rtol=1e-01,
+                atol=1e-01,
+            )
+
     @skip_if_lt_x_gpu(4)
     @parametrize("align_bytes", [4, 8, 16])
     def test_multimem_all_gather(self, align_bytes: int) -> None:
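
The `_verify_reduce_scatter_result` helper above checks the usual reduce-scatter contract: rank i's output equals the i-th last-dimension slice of the elementwise sum of all ranks' inputs. A single-process reference of that contract, useful when reasoning about the test (hypothetical helper, not part of the PR):

    import torch

    def reduce_scatter_reference(inputs):
        # inputs: one tensor per rank, all with identical shapes
        world_size = len(inputs)
        total = torch.stack(inputs).sum(0)         # elementwise sum across ranks
        width = total.shape[-1] // world_size      # each rank keeps one slice of the last dim
        return [total[..., r * width:(r + 1) * width] for r in range(world_size)]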

torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h

Lines changed: 1 addition & 1 deletion
@@ -314,7 +314,7 @@ __device__ __inline__ Vec<Alignment> ld_vec(const T* addr) {

 template <int Alignment, typename T>
 __device__ __inline__ void st_vec(T* addr, const Vec<Alignment>& vec) {
-#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
   CUDA_KERNEL_ASSERT(false);
 #else
   if constexpr (Alignment == 16) {

torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu

Lines changed: 197 additions & 14 deletions
@@ -463,6 +463,10 @@ at::Tensor one_shot_all_reduce_out_impl(
         local_input->numel() <= input.numel(),
         "one_shot_all_reduce: local input size must be smaller than symm buffer size.");
   }
+  if (input.numel() == 0) {
+    TORCH_CHECK(input.scalar_type() == out.scalar_type());
+    return out;
+  }
   auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
   TORCH_CHECK(
       symm_mem != nullptr,
@@ -555,9 +559,14 @@ at::Tensor one_shot_all_reduce_copy(
 }

 constexpr size_t two_shot_all_reduce_max_num_blocks = 24;
-constexpr size_t two_shot_all_reduce_max_num_threads = 512;
-
-template <typename T, int alignment, int k_world_size>
+constexpr size_t two_shot_all_reduce_max_num_threads = 1024;
+
+template <
+    typename T,
+    int alignment,
+    int k_world_size,
+    bool reduce_scatter = false,
+    bool split_last_dim = false>
 static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
     void two_shot_all_reduce_kernel(
         T** input_ptrs,
@@ -566,31 +575,48 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
         size_t numel,
         uint32_t** signal_pads,
         size_t rank,
-        size_t world_size) {
+        size_t world_size,
+        size_t last_dim_size = 0) {
   static_assert(alignment % sizeof(T) == 0);
   constexpr size_t numel_per_thread = alignment / sizeof(T);
-
+  int32_t N_last_dim =
+      last_dim_size / world_size; // used only for split_last_dim reduce_scatter
   sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
   __syncthreads();

   const size_t numel_per_rank =
-      at::round_up(numel, alignment * world_size) / world_size;
-  const size_t start = numel_per_rank * rank;
+      at::round_up(numel, numel_per_thread * world_size) / world_size;
+  const size_t start = split_last_dim ? last_dim_size / world_size * rank
+                                      : numel_per_rank * rank;

   auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
   auto stride = blockDim.x * gridDim.x * numel_per_thread;
   for (size_t i = offset; i < numel_per_rank; i += stride) {
-    if (start + i >= numel) {
-      continue;
+    if constexpr (!reduce_scatter) {
+      // we call reduce-scatter only with evenly divisible number of elements
+      if (start + i >= numel) {
+        continue;
+      }
+    }
+    size_t idx = i;
+    if constexpr (split_last_dim) {
+      idx = i / N_last_dim * last_dim_size + i % N_last_dim;
     }
     auto vec = load_and_reduce<T, alignment, k_world_size>(
-        input_ptrs, rank, world_size, input_offset + start + i);
-    // store to local buffer
-    st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
+        input_ptrs, rank, world_size, input_offset + start + idx);
+    // store to local buffer or to output
+    if constexpr (reduce_scatter) {
+      st_vec<alignment>(output_ptr + i, vec);
+    } else {
+      st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
+    }
   }

   __syncthreads();
   sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  if constexpr (reduce_scatter) {
+    return;
+  }
   __syncthreads();
   for (size_t i = offset; i < numel_per_rank; i += stride) {
     Vec<alignment> tmp[k_world_size];
@@ -611,8 +637,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
       if (remote_start + i >= numel) {
        continue;
      }
-      st_vec<alignment>(
-          output_ptr + remote_start + i, tmp[step]);
+      st_vec<alignment>(output_ptr + remote_start + i, tmp[step]);
     }
   }
   // need to make sure all blocks exit simultaneously so that the data
@@ -679,11 +704,28 @@ at::Tensor two_shot_all_reduce_impl(
       get_and_verify_alignment(input, "two_shot_all_reduce");

   if (output.has_value()) {
+    TORCH_CHECK(
+        output->is_contiguous(),
+        "two_shot_all_reduce: output must be contiguous.");
     const size_t output_alignment =
         get_and_verify_alignment(*output, "two_shot_all_reduce");
     TORCH_CHECK(
         alignment <= output_alignment,
         "two_shot_all_reduce: output alignment must be equal to or larger than input.");
+    TORCH_CHECK(
+        output->sizes() == input.sizes(),
+        "two_shot_all_reduce: input/output size mismatch, input.sizes(): ",
+        input.sizes(),
+        ", output.sizes(): ",
+        output->sizes());
+    if (input.numel() == 0) {
+      TORCH_CHECK(output->scalar_type() == input.scalar_type());
+      return *output;
+    }
+  } else {
+    if (input.numel() == 0) {
+      return input;
+    }
   }

   int num_blocks = 0, num_threads = 0;
@@ -764,6 +806,146 @@ at::Tensor two_shot_all_reduce_out(
     at::Tensor output) {
   return two_shot_all_reduce_impl(input, output, reduce_op, group_name);
 }
+
+at::Tensor reduce_scatter_out(
+    at::Tensor input,
+    std::string group_name,
+    bool split_last_dim,
+    at::Tensor output) {
+  TORCH_CHECK(
+      input.is_contiguous(), "reduce_scatter: input must be contiguous.");
+  TORCH_CHECK(
+      output.is_contiguous(), "reduce_scatter: output must be contiguous.");
+
+  auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
+  TORCH_CHECK(
+      symm_mem != nullptr,
+      "reduce_scatter: input must be allocated with empty_strided_p2p().");
+
+  const size_t alignment = get_and_verify_alignment(input, "reduce_scatter");
+
+  const size_t output_alignment =
+      get_and_verify_alignment(input, "reduce_scatter");
+
+  TORCH_CHECK(
+      input.numel() %
+              (symm_mem->get_world_size() *
+               (alignment / input.element_size())) ==
+          0,
+      "expected number of elements to be divisible by world_size * alignment, number of elements ",
+      input.numel(),
+      " world size ",
+      symm_mem->get_world_size(),
+      "alignment ",
+      alignment);
+
+  if (split_last_dim) {
+    TORCH_CHECK(input.dim() == output.dim());
+    bool are_equal_except_last = std::equal(
+        input.sizes().begin(), input.sizes().end() - 1, output.sizes().begin());
+    TORCH_CHECK(
+        are_equal_except_last,
+        "reduce_scatter expected input and output to have same sizes except in the last dimension");
+    TORCH_CHECK(
+        output.size(-1) == input.size(-1) / symm_mem->get_world_size(),
+        "reduce_scatter expected output last dim size to be input last dim size / world_size");
+
+    TORCH_CHECK(
+        input.size(-1) %
+                (symm_mem->get_world_size() *
+                 (alignment / input.element_size())) ==
+            0,
+        "expected last dimension to be divisible by world_size * alignment, last dimension ",
+        input.size(-1),
+        " world size ",
+        symm_mem->get_world_size(),
+        "alignment ",
+        alignment);
+  } else {
+    TORCH_CHECK(input.dim() == 1, "reduce_scatter expected 1D input");
+    TORCH_CHECK(output.dim() == 1, "reduce_scatter expected 1D output");
+    TORCH_CHECK(output.numel() == input.numel() / symm_mem->get_world_size());
+  }
+  if (input.numel() == 0) {
+    TORCH_CHECK(input.scalar_type() == output.scalar_type());
+    return output;
+  }
+
+  TORCH_CHECK(
+      output_alignment >= alignment,
+      "reduce_scatter: output alignment should be not smaller than input alignment");
+
+  int num_blocks = 0, num_threads = 0;
+  init_elementwise_launch_config(
+      input.numel(),
+      input.element_size(),
+      alignment,
+      symm_mem->get_world_size(),
+      two_shot_all_reduce_max_num_blocks,
+      two_shot_all_reduce_max_num_threads,
+      num_blocks,
+      num_threads);
+  if (split_last_dim) {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel<
+                  scalar_t,
+                  k_alignment,
+                  k_world_size,
+                  true,
+                  true>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      output.data_ptr<scalar_t>(),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size(),
+                      input.size(-1));
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+          });
+        });
+  } else {
+    AT_DISPATCH_FLOAT_AND_BFLOAT16(
+        input.scalar_type(), "two_shot_all_reduce", [&]() {
+          DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
+            DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() {
+              two_shot_all_reduce_kernel<
+                  scalar_t,
+                  k_alignment,
+                  k_world_size,
+                  true,
+                  false>
+                  <<<num_blocks,
+                     num_threads,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      reinterpret_cast<scalar_t**>(
+                          symm_mem->get_buffer_ptrs_dev()),
+                      output.data_ptr<scalar_t>(),
+                      input.storage_offset(),
+                      input.numel(),
+                      reinterpret_cast<uint32_t**>(
+                          symm_mem->get_signal_pad_ptrs_dev()),
+                      symm_mem->get_rank(),
+                      symm_mem->get_world_size(),
+                      input.size(-1));
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+          });
+        });
+  }
+  return output;
+}
 } // namespace
 #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030

@@ -899,6 +1081,7 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
   m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
+  m.impl("reduce_scatter_out", ::reduce_scatter_out);

   m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
 #endif
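
The least obvious part of the kernel change above is the split_last_dim index mapping (idx = i / N_last_dim * last_dim_size + i % N_last_dim): each rank writes a contiguous output, so output index i has to be translated back to an offset inside the full rows of the symmetric input buffer. A small Python sketch of that mapping, illustrative only, with names mirroring the kernel:

    # Assumes last_dim_size is divisible by world_size, as the host-side checks enforce.
    def input_offset_for_output_index(i, rank, last_dim_size, world_size):
        n_last_dim = last_dim_size // world_size   # output elements per row on this rank
        row, col = divmod(i, n_last_dim)           # row index and position within this rank's slice
        start = rank * n_last_dim                  # this rank's slice begins here in every row
        return start + row * last_dim_size + col   # == start + idx in the kernel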

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 4 additions & 0 deletions
@@ -250,6 +250,10 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "two_shot_all_reduce_out(Tensor(a!) input, str reduce_op, str group_name, Tensor(b!) output) -> Tensor(b!)");

+  // note this implementation also modified the input tensor
+  m.def(
+      "reduce_scatter_out(Tensor(a!) input, str group_name, bool split_last_dim, Tensor(b!) output) -> Tensor(b!)");
+
   // An mm that supports consuming asynchronous input. It guarantees the
   // following rasterization order, and that the corresponding signal arrives
   // before an input chunk is consumed.
