Skip to content

Commit 64db9f7

Browse files
authored
[Runtime] Introduce MSCCLPP with NCCL equivalent interface (#16804)
* [Runtime] Introduce MSCCLPP with NCCL equivalent interface * Add a fast and simple AllReduce kernel (sum only) using using mscclpp smChannel scratch for small reductions up to 2**24 bytes.
1 parent 3ce87cb commit 64db9f7

File tree

6 files changed

+1161
-2
lines changed

6 files changed

+1161
-2
lines changed

3rdparty/mscclpp/include/common.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#ifndef MSCCL_COMMON_HPP_
5+
#define MSCCL_COMMON_HPP_
6+
7+
#if defined(__HIP_PLATFORM_AMD__)
8+
#define WARP_SIZE 64
9+
#define __syncwarp() __builtin_amdgcn_wave_barrier()
10+
#else
11+
#define WARP_SIZE 32
12+
#endif
13+
14+
constexpr int NRANKS_PER_NODE = 8;
15+
constexpr int SCRATCH_SIZE = 1024 * 1024 * 70; // 35 thread-blocks * 8 ranks * 256KB = 70MB
16+
17+
template <typename To, typename From>
18+
__forceinline__ __device__ To bit_cast(const From& src) {
19+
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
20+
21+
union {
22+
From f;
23+
To t;
24+
} u;
25+
u.f = src;
26+
return u.t;
27+
}
28+
29+
template <typename T>
30+
__forceinline__ __device__ T add_elements(T a, T b) {
31+
return a + b;
32+
}
33+
34+
template <>
35+
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
36+
return __hadd2(a, b);
37+
}
38+
39+
template <typename T>
40+
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
41+
int4 ret;
42+
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
43+
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
44+
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
45+
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
46+
return ret;
47+
}
48+
49+
template <typename T>
50+
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
51+
return add_vectors_helper<T>(a, b);
52+
}
53+
54+
template <>
55+
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
56+
return add_vectors_helper<__half2>(a, b);
57+
}
58+
59+
template <typename T>
60+
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
61+
uint2 ret;
62+
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
63+
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
64+
return ret;
65+
}
66+
67+
template <typename T>
68+
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
69+
return add_vectors_helper<T>(a, b);
70+
}
71+
72+
template <>
73+
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
74+
return add_vectors_helper<__half2>(a, b);
75+
}
76+
77+
template <typename T>
78+
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
79+
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
80+
}
81+
82+
template <typename T>
83+
__forceinline__ __device__ int add_vectors(int a, int b) {
84+
return add_vectors_helper<T>(a, b);
85+
}
86+
87+
template <>
88+
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
89+
return add_vectors_helper<__half2>(a, b);
90+
}
91+
92+
template <typename T>
93+
__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
94+
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
95+
}
96+
97+
template <typename T>
98+
__forceinline__ __device__ uint32_t add_vectors(uint32_t a, uint32_t b) {
99+
return add_vectors_helper<T>(a, b);
100+
}
101+
102+
template <>
103+
__forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
104+
return add_vectors_helper<__half2>(a, b);
105+
}
106+
107+
#endif // MSCCL_COMMON_HPP_

0 commit comments

Comments
 (0)