diff --git a/src/base.jl b/src/base.jl
index cb1efc9..65c03cc 100644
--- a/src/base.jl
+++ b/src/base.jl
@@ -53,3 +53,4 @@ ncclDataType_t(::Type{UInt64}) = ncclUint64
 ncclDataType_t(::Type{Float16}) = ncclFloat16
 ncclDataType_t(::Type{Float32}) = ncclFloat32
 ncclDataType_t(::Type{Float64}) = ncclFloat64
+ncclDataType_t(::Type{Complex{T}}) where {T} = ncclDataType_t(T)
diff --git a/src/collective.jl b/src/collective.jl
index 054fa3a..1608cb6 100644
--- a/src/collective.jl
+++ b/src/collective.jl
@@ -1,3 +1,6 @@
+count(X::AbstractArray) = length(X)  # number of NCCL elements in a buffer
+count(X::AbstractArray{<:Complex}) = 2 * length(X)  # complex is sent as two reals
+
 """
     NCCL.Allreduce!(
         sendbuf, recvbuf, op, comm::Communicator;
@@ -11,11 +14,11 @@ or [`NCCL.avg`](@ref)), writing the result to `recvbuf` to all ranks.
 """
 function Allreduce!(sendbuf, recvbuf, op, comm::Communicator;
         stream::CuStream=default_device_stream(comm))
-    count = length(recvbuf)
-    @assert length(sendbuf) == count
+    a_count = count(recvbuf)
+    @assert count(sendbuf) == a_count
     data_type = ncclDataType_t(eltype(recvbuf))
     _op = ncclRedOp_t(op)
-    ncclAllReduce(sendbuf, recvbuf, count, data_type, _op, comm, stream)
+    ncclAllReduce(sendbuf, recvbuf, a_count, data_type, _op, comm, stream)
     return recvbuf
 end
 
@@ -47,8 +50,8 @@ Copies array the `sendbuf` on rank `root` to `recvbuf` on all ranks.
 function Broadcast!(sendbuf, recvbuf, comm::Communicator; root::Integer=0,
         stream::CuStream=default_device_stream(comm))
     data_type = ncclDataType_t(eltype(recvbuf))
-    count = length(recvbuf)
-    ncclBroadcast(sendbuf, recvbuf, count, data_type, root, comm, stream)
+    a_count = count(recvbuf)
+    ncclBroadcast(sendbuf, recvbuf, a_count, data_type, root, comm, stream)
     return recvbuf
 end
 function Broadcast!(sendrecvbuf, comm::Communicator; root::Integer=0,
@@ -72,9 +75,9 @@ or `[`NCCL.avg`](@ref)`), writing the result to `recvbuf` on rank `root`.
 function Reduce!(sendbuf, recvbuf, op, comm::Communicator; root::Integer=0,
         stream::CuStream=default_device_stream(comm))
     data_type = ncclDataType_t(eltype(recvbuf))
-    count = length(recvbuf)
+    a_count = count(recvbuf)
     _op = ncclRedOp_t(op)
-    ncclReduce(sendbuf, recvbuf, count, data_type, _op, root, comm, stream)
+    ncclReduce(sendbuf, recvbuf, a_count, data_type, _op, root, comm, stream)
     return recvbuf
 end
 function Reduce!(sendrecvbuf, op, comm::Communicator; root::Integer=0,
@@ -96,9 +99,9 @@ Concatenate `sendbuf` from each rank into `recvbuf` on all ranks.
 function Allgather!(sendbuf, recvbuf, comm::Communicator;
         stream::CuStream=default_device_stream(comm))
     data_type = ncclDataType_t(eltype(recvbuf))
-    sendcount = length(sendbuf)
-    @assert length(recvbuf) == sendcount * size(comm)
+    sendcount = count(sendbuf)
+    @assert count(recvbuf) == sendcount * size(comm)
     ncclAllGather(sendbuf, recvbuf, sendcount, data_type, comm, stream)
     return recvbuf
 end
 
@@ -117,10 +120,10 @@ scattered over the devices such that `recvbuf` on each rank will contain the
 """
 function ReduceScatter!(sendbuf, recvbuf, op, comm::Communicator;
         stream::CuStream=default_device_stream(comm))
-    recvcount = length(recvbuf)
-    @assert length(sendbuf) == recvcount * size(comm)
+    recvcount = count(recvbuf)
+    @assert count(sendbuf) == recvcount * size(comm)
     data_type = ncclDataType_t(eltype(recvbuf))
     _op = ncclRedOp_t(op)
     ncclReduceScatter(sendbuf, recvbuf, recvcount, data_type, _op, comm, stream)
     return recvbuf
 end
diff --git a/src/pointtopoint.jl b/src/pointtopoint.jl
index c0d9544..b712dee 100644
--- a/src/pointtopoint.jl
+++ b/src/pointtopoint.jl
@@ -13,9 +13,9 @@ called.
""" function Send(sendbuf, comm::Communicator; dest::Integer, stream::CuStream=default_device_stream(comm)) - count = length(sendbuf) + a_count = count(sendbuf) datatype = ncclDataType_t(eltype(sendbuf)) - ncclSend(sendbuf, count, datatype, dest, comm, stream) + ncclSend(sendbuf, a_count, datatype, dest, comm, stream) return nothing end @@ -34,8 +34,8 @@ Write the data from a matching [`Send`](@ref) on rank `source` into `recvbuf`. """ function Recv!(recvbuf, comm::Communicator; source::Integer, stream::CuStream=default_device_stream(comm)) - count = length(recvbuf) + a_count = count(recvbuf) datatype = ncclDataType_t(eltype(recvbuf)) - ncclRecv(recvbuf, count, datatype, source, comm, stream) + ncclRecv(recvbuf, a_count, datatype, source, comm, stream) return recvbuf.data end diff --git a/test/runtests.jl b/test/runtests.jl index df695c0..1dc2dfb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,171 +26,187 @@ end devs = CUDA.devices() comms = NCCL.Communicators(devs) - @testset "sum" begin - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 + @testset "$T" for T in (Float64, ComplexF64) + @testset "sum" begin + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + end + end + answer = sum(1:length(devs)) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) + end + end + + @testset "NCCL.avg" begin + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = 
CUDA.zeros(T, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) + end + end + answer = sum(1:length(devs)) / length(devs) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .≈ answer) + end + end + end +end + +@testset "Broadcast!" begin + devs = CUDA.devices() + comms = NCCL.Communicators(devs) + + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + root = 0 for (ii, dev) in enumerate(devs) CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) + sendbuf[ii] = (ii - 1) == root ? CuArray(fill(T(1.0), 512)) : CUDA.zeros(T, 512) + recvbuf[ii] = CUDA.zeros(T, 512) end NCCL.group() do for ii in 1:length(devs) - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + NCCL.Broadcast!(sendbuf[ii], recvbuf[ii], comms[ii]; root) end end - answer = sum(1:length(devs)) + answer = 1.0 for (ii, dev) in enumerate(devs) device!(ii - 1) crecv = collect(recvbuf[ii]) @test all(crecv .== answer) end end +end - @testset "NCCL.avg" begin - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 +@testset "Reduce!" 
begin + devs = CUDA.devices() + comms = NCCL.Communicators(devs) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + root = 0 for (ii, dev) in enumerate(devs) CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) + sendbuf[ii] = CuArray(fill(T(ii), 512)) + recvbuf[ii] = CUDA.zeros(T, 512) end NCCL.group() do for ii in 1:length(devs) - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) + NCCL.Reduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]; root) end end - answer = sum(1:length(devs)) / length(devs) for (ii, dev) in enumerate(devs) + answer = (ii - 1) == root ? sum(1:length(devs)) : 0.0 device!(ii - 1) crecv = collect(recvbuf[ii]) - @test all(crecv .≈ answer) + @test all(crecv .== answer) end end end -@testset "Broadcast!" begin +@testset "Allgather!" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - root = 0 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = (ii - 1) == root ? CuArray(fill(Float64(1.0), 512)) : CUDA.zeros(Float64, 512) - recvbuf[ii] = CUDA.zeros(Float64, 512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Broadcast!(sendbuf[ii], recvbuf[ii], comms[ii]; root) - end - end - answer = 1.0 - for (ii, dev) in enumerate(devs) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) - end -end -@testset "Reduce!" 
begin - devs = CUDA.devices() - comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - root = 0 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), 512)) - recvbuf[ii] = CUDA.zeros(Float64, 512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Reduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]; root) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), 512)) + recvbuf[ii] = CUDA.zeros(T, length(devs)*512) end - end - for (ii, dev) in enumerate(devs) - answer = (ii - 1) == root ? sum(1:length(devs)) : 0.0 - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) - end -end - -@testset "Allgather!" begin - devs = CUDA.devices() - comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), 512)) - recvbuf[ii] = CUDA.zeros(Float64, length(devs)*512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Allgather!(sendbuf[ii], recvbuf[ii], comms[ii]) + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allgather!(sendbuf[ii], recvbuf[ii], comms[ii]) + end + end + answer = vec(repeat(1:length(devs), inner=512)) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - answer = vec(repeat(1:length(devs), inner=512)) - for (ii, dev) in enumerate(devs) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end @testset "ReduceScatter!" 
begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(vec(repeat(collect(1:length(devs)), inner=2))) - recvbuf[ii] = CUDA.zeros(Float64, 2) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.ReduceScatter!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(vec(repeat(collect(1:length(devs)), inner=2))) + recvbuf[ii] = CUDA.zeros(T, 2) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.ReduceScatter!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + end + end + for (ii, dev) in enumerate(devs) + answer = length(devs)*ii + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - for (ii, dev) in enumerate(devs) - answer = length(devs)*ii - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end @testset "Send/Recv" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) - end - NCCL.group() do - for ii in 1:length(devs) - comm = comms[ii] - dest = mod(NCCL.rank(comm)+1, NCCL.size(comm)) - source = mod(NCCL.rank(comm)-1, NCCL.size(comm)) - NCCL.Send(sendbuf[ii], comm; dest) - NCCL.Recv!(recvbuf[ii], comm; source) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) 
+ CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + + NCCL.group() do + for ii in 1:length(devs) + comm = comms[ii] + dest = mod(NCCL.rank(comm)+1, NCCL.size(comm)) + source = mod(NCCL.rank(comm)-1, NCCL.size(comm)) + NCCL.Send(sendbuf[ii], comm; dest) + NCCL.Recv!(recvbuf[ii], comm; source) + end + end + for (ii, dev) in enumerate(devs) + answer = mod1(ii - 1, length(devs)) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - for (ii, dev) in enumerate(devs) - answer = mod1(ii - 1, length(devs)) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end