Skip to content

Commit d9b4e16

Browse files
authored
Merge pull request #4 from JuliaGPU/tb/devcount
Device-count agnostic tests
2 parents 39d28f1 + 413ee7a commit d9b4e16

File tree

3 files changed

+193
-10
lines changed

3 files changed

+193
-10
lines changed

Manifest.toml

Lines changed: 183 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,183 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
[[AbstractFFTs]]
4+
deps = ["LinearAlgebra"]
5+
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
6+
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
7+
version = "0.4.1"
8+
9+
[[Adapt]]
10+
deps = ["LinearAlgebra"]
11+
git-tree-sha1 = "82dab828020b872fa9efd3abec1152b075bc7cbf"
12+
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
13+
version = "1.0.0"
14+
15+
[[Base64]]
16+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
17+
18+
[[CEnum]]
19+
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
20+
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
21+
version = "0.2.0"
22+
23+
[[CUDAapi]]
24+
deps = ["Libdl", "Logging"]
25+
git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3"
26+
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
27+
version = "2.0.0"
28+
29+
[[CUDAdrv]]
30+
deps = ["CEnum", "CUDAapi", "Printf"]
31+
git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921"
32+
uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
33+
version = "4.0.4"
34+
35+
[[CUDAnative]]
36+
deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
37+
git-tree-sha1 = "93f6c917ab2a9b5bb54f8f738f4ec1a6693cb716"
38+
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
39+
version = "2.5.5"
40+
41+
[[Compat]]
42+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
43+
git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
44+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
45+
version = "2.2.0"
46+
47+
[[CuArrays]]
48+
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
49+
git-tree-sha1 = "4757376a85ffb27d4c4f6cdf9635261e6c3a5fec"
50+
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
51+
version = "1.4.7"
52+
53+
[[DataStructures]]
54+
deps = ["InteractiveUtils", "OrderedCollections"]
55+
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
56+
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
57+
version = "0.17.6"
58+
59+
[[Dates]]
60+
deps = ["Printf"]
61+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
62+
63+
[[DelimitedFiles]]
64+
deps = ["Mmap"]
65+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
66+
67+
[[Distributed]]
68+
deps = ["Random", "Serialization", "Sockets"]
69+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
70+
71+
[[GPUArrays]]
72+
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
73+
git-tree-sha1 = "a0a3b927b1a06e63fb8b91950cc7df340b7d912c"
74+
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
75+
version = "2.0.0"
76+
77+
[[InteractiveUtils]]
78+
deps = ["Markdown"]
79+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
80+
81+
[[LLVM]]
82+
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
83+
git-tree-sha1 = "74fe444b8b6d1ac01d639b2f9eaf395bcc2e24fc"
84+
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
85+
version = "1.3.2"
86+
87+
[[LibGit2]]
88+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
89+
90+
[[Libdl]]
91+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
92+
93+
[[LinearAlgebra]]
94+
deps = ["Libdl"]
95+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
96+
97+
[[Logging]]
98+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
99+
100+
[[MacroTools]]
101+
deps = ["Compat", "DataStructures", "Test"]
102+
git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06"
103+
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
104+
version = "0.5.2"
105+
106+
[[Markdown]]
107+
deps = ["Base64"]
108+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
109+
110+
[[Mmap]]
111+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
112+
113+
[[NNlib]]
114+
deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"]
115+
git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8"
116+
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
117+
version = "0.6.0"
118+
119+
[[OrderedCollections]]
120+
deps = ["Random", "Serialization", "Test"]
121+
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
122+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
123+
version = "1.1.0"
124+
125+
[[Pkg]]
126+
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
127+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
128+
129+
[[Printf]]
130+
deps = ["Unicode"]
131+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
132+
133+
[[REPL]]
134+
deps = ["InteractiveUtils", "Markdown", "Sockets"]
135+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
136+
137+
[[Random]]
138+
deps = ["Serialization"]
139+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
140+
141+
[[Requires]]
142+
deps = ["Test"]
143+
git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
144+
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
145+
version = "0.5.2"
146+
147+
[[SHA]]
148+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
149+
150+
[[Serialization]]
151+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
152+
153+
[[SharedArrays]]
154+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
155+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
156+
157+
[[Sockets]]
158+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
159+
160+
[[SparseArrays]]
161+
deps = ["LinearAlgebra", "Random"]
162+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
163+
164+
[[Statistics]]
165+
deps = ["LinearAlgebra", "SparseArrays"]
166+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
167+
168+
[[Test]]
169+
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
170+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
171+
172+
[[TimerOutputs]]
173+
deps = ["Printf"]
174+
git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03"
175+
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
176+
version = "0.5.3"
177+
178+
[[UUIDs]]
179+
deps = ["Random", "SHA"]
180+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
181+
182+
[[Unicode]]
183+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

src/NCCL.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,8 @@ function __init__()
3535
end
3636

3737
try
38-
if version() < v"2"
39-
silent || @warn "NCCL.jl only supports NCCL 2.x (you are using $(version()))"
38+
if version() < v"2.4"
39+
silent || @warn "NCCL.jl only supports NCCL 2.4 or higher (you are using $(version()))"
4040
end
4141
catch ex
4242
# don't actually fail to keep the package loadable

test/runtests.jl

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,11 @@ using Test
44
@testset "NCCL.jl" begin
55
@testset "Communicator" begin
66
comms = Communicator(CUDAdrv.devices())
7-
@test device(comms[1]) == 0
8-
@test device(comms[2]) == 1
9-
@test size(comms[1]) == 2
10-
@test rank(comms[1]) == 0
11-
@test rank(comms[2]) == 1
7+
for (i,dev) in enumerate(CUDAdrv.devices())
8+
@test rank(comms[i]) == i-1
9+
@test device(comms[i]) == i-1
10+
@test size(comms[i]) == length(CUDAdrv.devices())
11+
end
1212
id = UniqueID()
1313
#=num_devs = length(CUDAdrv.devices())
1414
comm = Communicator(num_devs, id, 0)
@@ -31,7 +31,7 @@ using Test
3131
Allreduce!(sendbuf[ii], recvbuf[ii], 512, NCCL.ncclSum, comms[ii], stream=streams[ii])
3232
end
3333
groupEnd()
34-
answer = sum(1:length(devs))
34+
answer = sum(1:length(devs))
3535
for (ii, dev) in enumerate(devs)
3636
device!(ii - 1)
3737
crecv = collect(recvbuf[ii])
@@ -105,7 +105,7 @@ using Test
105105
Allgather!(sendbuf[ii], recvbuf[ii], 512, comms[ii], stream=streams[ii])
106106
end
107107
groupEnd()
108-
answer = vec(repeat(1:length(devs), inner=512))
108+
answer = vec(repeat(1:length(devs), inner=512))
109109
for (ii, dev) in enumerate(devs)
110110
device!(ii - 1)
111111
crecv = collect(recvbuf[ii])
@@ -127,7 +127,7 @@ using Test
127127
end
128128
groupStart()
129129
for ii in 1:length(devs)
130-
ReduceScatter!(sendbuf[ii], recvbuf[ii], 2, NCCL.ncclSum, comms[ii], stream=streams[ii])
130+
ReduceScatter!(sendbuf[ii], recvbuf[ii], length(devs), NCCL.ncclSum, comms[ii], stream=streams[ii])
131131
end
132132
groupEnd()
133133
for (ii, dev) in enumerate(devs)

0 commit comments

Comments
 (0)