Skip to content

Commit bb10cac

Browse files
committed
Slightly fancier test for allreduce
1 parent d9b4e16 commit bb10cac

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

test/runtests.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ using Test
99
@test device(comms[i]) == i-1
1010
@test size(comms[i]) == length(CUDAdrv.devices())
1111
end
12-
id = UniqueID()
13-
#=num_devs = length(CUDAdrv.devices())
14-
comm = Communicator(num_devs, id, 0)
15-
@test device(comm) == 0=#
12+
#id = UniqueID()
13+
#num_devs = length(CUDAdrv.devices())
14+
#comm = Communicator(num_devs, id, 0)
15+
#@test device(comm) == 0
1616
end
1717
@testset "Allreduce!" begin
1818
devs = CUDAdrv.devices()
@@ -37,6 +37,32 @@ using Test
3737
crecv = collect(recvbuf[ii])
3838
@test all(crecv .== answer)
3939
end
40+
# more complex example?
41+
recvbuf = Vector{CuMatrix{Float64}}(undef, length(devs))
42+
sendbuf = Vector{CuMatrix{Float64}}(undef, length(devs))
43+
streams = Vector{CuStream}(undef, length(devs))
44+
m = 256
45+
k = 512
46+
n = 256
47+
As = [rand(m, k) for i in 1:length(devs)]
48+
Bs = [rand(k, n) for i in 1:length(devs)]
49+
C = sum(As .* Bs)
50+
for (ii, dev) in enumerate(devs)
51+
CUDAnative.device!(ii - 1)
52+
sendbuf[ii] = cu(As[ii]) * cu(Bs[ii])
53+
recvbuf[ii] = CuArrays.zeros(Float64, m, n)
54+
streams[ii] = CuStream()
55+
end
56+
groupStart()
57+
for ii in 1:length(devs)
58+
Allreduce!(sendbuf[ii], recvbuf[ii], m*n, NCCL.ncclSum, comms[ii], stream=streams[ii])
59+
end
60+
groupEnd()
61+
for (ii, dev) in enumerate(devs)
62+
device!(ii - 1)
63+
crecv = collect(recvbuf[ii])
64+
@test crecv ≈ C rtol=1e-6
65+
end
4066
end
4167
@testset "Broadcast!" begin
4268
devs = CUDAdrv.devices()

0 commit comments

Comments
 (0)