@@ -9,10 +9,10 @@ using Test
9
9
@test device (comms[i]) == i- 1
10
10
@test size (comms[i]) == length (CUDAdrv. devices ())
11
11
end
12
- id = UniqueID ()
13
- #= num_devs = length(CUDAdrv.devices())
14
- comm = Communicator(num_devs, id, 0)
15
- @test device(comm) == 0=#
12
+ # id = UniqueID()
13
+ # num_devs = length(CUDAdrv.devices())
14
+ # comm = Communicator(num_devs, id, 0)
15
+ # @test device(comm) == 0
16
16
end
17
17
@testset " Allreduce!" begin
18
18
devs = CUDAdrv. devices ()
@@ -37,6 +37,32 @@ using Test
37
37
crecv = collect (recvbuf[ii])
38
38
@test all (crecv .== answer)
39
39
end
# Larger example: every device contributes one m×n matrix product and the
# all-reduced sum is checked against the same sum computed on the host.
# Relies on `devs` and `comms` defined earlier in this testset.
m, k, n = 256, 512, 256
ndevs = length(devs)
recvbuf = Vector{CuMatrix{Float64}}(undef, ndevs)
sendbuf = Vector{CuMatrix{Float64}}(undef, ndevs)
streams = Vector{CuStream}(undef, ndevs)
As = [rand(m, k) for _ in 1:ndevs]
Bs = [rand(k, n) for _ in 1:ndevs]
# Host reference: sum over devices of the matrix products As[i] * Bs[i].
C = sum(A * B for (A, B) in zip(As, Bs))
for ii in 1:ndevs
    CUDAnative.device!(ii - 1)
    # Upload with CuArray(...) so the operands stay Float64; `cu` converts
    # floating-point data to Float32, which is too coarse for rtol=1e-6 below.
    sendbuf[ii] = CuArray(As[ii]) * CuArray(Bs[ii])
    recvbuf[ii] = CuArrays.zeros(Float64, m, n)
    streams[ii] = CuStream()
end
# Issue all per-device reductions inside one NCCL group.
groupStart()
for ii in 1:ndevs
    Allreduce!(
        sendbuf[ii], recvbuf[ii], m * n, NCCL.ncclSum, comms[ii];
        stream=streams[ii],
    )
end
groupEnd()
for ii in 1:ndevs
    CUDAnative.device!(ii - 1)
    crecv = collect(recvbuf[ii])
    @test crecv ≈ C rtol=1e-6
end
40
66
end
41
67
@testset " Broadcast!" begin
42
68
devs = CUDAdrv. devices ()
0 commit comments