diff --git a/Project.toml b/Project.toml index 25456f5..8fa3e11 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NCCL" uuid = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" -version = "0.1.1" +version = "0.1.2" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/res/wrap/Manifest.toml b/res/wrap/Manifest.toml index 556b402..f9b657d 100644 --- a/res/wrap/Manifest.toml +++ b/res/wrap/Manifest.toml @@ -1,92 +1,65 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.4" +julia_version = "1.11.6" manifest_format = "2.0" project_hash = "c24c6cc39efbdedbfd117abf2a3fb1c76d6d6317" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" +version = "1.1.2" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" [[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" - -[[deps.CSTParser]] -deps = ["Tokenize"] -git-tree-sha1 = "3ddd48d200eb8ddf9cb3e0189fc059fd49b97c1f" -uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "3.3.6" +version = "0.5.0" [[deps.CUDA_Driver_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "1e42ef1bdb45487ff28de16182c0df4920181dc3" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "12621de83838b5ce6a185050db5a184f4540679b" uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.7.0+0" +version = "13.0.0+0" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "9704e50c9158cf8896c2776b8dbc5edd136caf80" +git-tree-sha1 = "cc727d90c9769db27945219f9ba149dbddc74f06" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.10.1+0" +version = "0.19.0+0" [[deps.Clang]] deps = ["CEnum", "Clang_jll", "Downloads", "Pkg", "TOML"] -git-tree-sha1 = "17e51c980e1797210f30155d18ea5a2f02b80833" +git-tree-sha1 = "6bf09af911d9df84656ddd0c02b5e449c18adba8" uuid = "40e3b903-d033-50b4-a0cc-940c62c95e31" -version = "0.17.7" +version = "0.19.0" [[deps.Clang_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "TOML", "Zlib_jll", "libLLVM_jll"] -git-tree-sha1 = "124bb00d4ceace456054f17c7cb01e5c8195c609" +deps = ["Artifacts", "JLLWrappers", "Libdl", "TOML", "Zlib_jll", "libLLVM_jll"] +git-tree-sha1 = "ebfc8e89823ec2c85ed9fedabe52db149da4c5ec" uuid = "0ee61d77-7f21-5576-8119-9fcc46b10100" -version = "14.0.6+4" +version = "16.0.6+5" [[deps.CommonMark]] -deps = ["Crayons", "JSON", "PrecompileTools", "URIs"] -git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071" +deps = ["PrecompileTools"] +git-tree-sha1 = "351d6f4eaf273b753001b2de4dffb8279b100769" uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6" -version = "0.8.12" - -[[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.10.1" - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - - [deps.Compat.weakdeps] - Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" - LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "0.9.1" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" +version = "1.1.1+0" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -95,37 +68,34 @@ version = "1.6.0" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" [[deps.Glob]] git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" version = "1.3.1" -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "0533e564aae234aff59ab625543145446d8b6ec2" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" +version = "1.7.1" [[deps.JuliaFormatter]] -deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"] -git-tree-sha1 = "8f5295e46f594ad2d8652f1098488a77460080cd" +deps = ["CommonMark", "Glob", "JuliaSyntax", "PrecompileTools", "TOML"] +git-tree-sha1 = "f512fefd5fdc7dd1ca05778f08f91e9e4c9fdc37" uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -version = "1.0.45" +version = "2.1.6" + +[[deps.JuliaSyntax]] +git-tree-sha1 = "937da4713526b96ac9a178e2035019d3b78ead4a" +uuid = "70703baa-626e-46a2-a12c-08ffd08c73b4" +version = "0.4.10" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.11.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] @@ -135,11 +105,17 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" +version = "8.6.0+0" [[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.7.2+0" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] @@ -148,86 +124,73 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "2.28.6+0" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" +version = "2023.12.12" [[deps.NCCL_jll]] deps = ["Artifacts", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "9dcc505d4c7267cb4e1904b3c3f8049ef61d0b4a" +git-tree-sha1 = "fe98fcd75fc82d1291dde2deafabecc9a2135190" uuid = "4d6d38e4-5b87-5e63-912a-873ff2d649b7" -version = "2.19.4+0" +version = "2.26.5+1" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" -[[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "a935806434c9d4c506ba941871b327b96d41f2bf" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.0" - [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" +version = "1.11.0" + + [deps.Pkg.extensions] + REPLExt = "REPL" + + [deps.Pkg.weakdeps] + REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.0" +version = "1.2.1" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" +git-tree-sha1 = "0f27480397253da18fe2c12a4ba4eb9eb208bf3d" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.1" +version = "1.5.0" [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +version = "1.11.0" [[deps.Random]] -deps = ["SHA", "Serialization"] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" @@ -238,39 +201,31 @@ deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" -[[deps.Tokenize]] -git-tree-sha1 = "0454d9a9bad2400c7ccad19ca832a2ef5a8bc3a1" -uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.26" - -[[deps.URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" - [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" +version = "1.2.13+1" [[deps.libLLVM_jll]] deps = ["Artifacts", "Libdl"] uuid = "8f36deef-c2a5-5394-99ed-8e07531fb29a" -version = "14.0.6+3" +version = "16.0.6+5" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" +version = "1.59.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" +version = "17.4.0+2" diff --git a/res/wrap/prologue.jl b/res/wrap/prologue.jl index 157752d..51701a2 100644 --- a/res/wrap/prologue.jl +++ b/res/wrap/prologue.jl @@ -1,8 +1,7 @@ const NULL = C_NULL const INT_MIN = typemin(Cint) -import CUDA.APIUtils: @checked -import CUDA: CuPtr, CUstream +import CUDA: CuPtr, CUstream, @checked function check(f) res = f()::ncclResult_t diff --git a/src/base.jl b/src/base.jl index cb1efc9..65c03cc 100644 --- a/src/base.jl +++ b/src/base.jl @@ -53,3 +53,4 @@ ncclDataType_t(::Type{UInt64}) = ncclUint64 ncclDataType_t(::Type{Float16}) = ncclFloat16 ncclDataType_t(::Type{Float32}) = ncclFloat32 ncclDataType_t(::Type{Float64}) = ncclFloat64 +ncclDataType_t(::Type{Complex{T}}) where {T} = ncclDataType_t(T) diff --git a/src/collective.jl b/src/collective.jl index 054fa3a..1608cb6 100644 --- a/src/collective.jl +++ b/src/collective.jl @@ -1,3 +1,6 @@ +count(X::CuArray{T}) where {T} = length(X) +count(X::CuArray{Complex{T}}) where {T} = 2*length(X) + """ NCCL.Allreduce!( sendbuf, recvbuf, op, comm::Communicator; @@ -11,11 +14,11 @@ or [`NCCL.avg`](@ref)), writing the result to `recvbuf` to all ranks. """ function Allreduce!(sendbuf, recvbuf, op, comm::Communicator; stream::CuStream=default_device_stream(comm)) - count = length(recvbuf) - @assert length(sendbuf) == count + a_count = count(recvbuf) + @assert count(sendbuf) == a_count data_type = ncclDataType_t(eltype(recvbuf)) _op = ncclRedOp_t(op) - ncclAllReduce(sendbuf, recvbuf, count, data_type, _op, comm, stream) + ncclAllReduce(sendbuf, recvbuf, a_count, data_type, _op, comm, stream) return recvbuf end @@ -47,8 +50,8 @@ Copies array the `sendbuf` on rank `root` to `recvbuf` on all ranks. function Broadcast!(sendbuf, recvbuf, comm::Communicator; root::Integer=0, stream::CuStream=default_device_stream(comm)) data_type = ncclDataType_t(eltype(recvbuf)) - count = length(recvbuf) - ncclBroadcast(sendbuf, recvbuf, count, data_type, root, comm, stream) + a_count = count(recvbuf) + ncclBroadcast(sendbuf, recvbuf, a_count, data_type, root, comm, stream) return recvbuf end function Broadcast!(sendrecvbuf, comm::Communicator; root::Integer=0, @@ -72,9 +75,9 @@ or `[`NCCL.avg`](@ref)`), writing the result to `recvbuf` on rank `root`. function Reduce!(sendbuf, recvbuf, op, comm::Communicator; root::Integer=0, stream::CuStream=default_device_stream(comm)) data_type = ncclDataType_t(eltype(recvbuf)) - count = length(recvbuf) + a_count = count(recvbuf) _op = ncclRedOp_t(op) - ncclReduce(sendbuf, recvbuf, count, data_type, _op, root, comm, stream) + ncclReduce(sendbuf, recvbuf, a_count, data_type, _op, root, comm, stream) return recvbuf end function Reduce!(sendrecvbuf, op, comm::Communicator; root::Integer=0, @@ -96,9 +99,9 @@ Concatenate `sendbuf` from each rank into `recvbuf` on all ranks. function Allgather!(sendbuf, recvbuf, comm::Communicator; stream::CuStream=default_device_stream(comm)) data_type = ncclDataType_t(eltype(recvbuf)) - sendcount = length(sendbuf) - @assert length(recvbuf) == sendcount * size(comm) - ncclAllGather(sendbuf, recvbuf, sendcount, data_type, comm, stream) + senda_count = count(sendbuf) + @assert count(recvbuf) == senda_count * size(comm) + ncclAllGather(sendbuf, recvbuf, senda_count, data_type, comm, stream) return recvbuf end @@ -117,10 +120,10 @@ scattered over the devices such that `recvbuf` on each rank will contain the """ function ReduceScatter!(sendbuf, recvbuf, op, comm::Communicator; stream::CuStream=default_device_stream(comm)) - recvcount = length(recvbuf) - @assert length(sendbuf) == recvcount * size(comm) + recva_count = count(recvbuf) + @assert count(sendbuf) == recva_count * size(comm) data_type = ncclDataType_t(eltype(recvbuf)) _op = ncclRedOp_t(op) - ncclReduceScatter(sendbuf, recvbuf, recvcount, data_type, _op, comm, stream) + ncclReduceScatter(sendbuf, recvbuf, recva_count, data_type, _op, comm, stream) return recvbuf end diff --git a/src/libnccl.jl b/src/libnccl.jl index 16c36ba..1f776af 100644 --- a/src/libnccl.jl +++ b/src/libnccl.jl @@ -3,17 +3,17 @@ module LibNCCL using NCCL_jll export NCCL_jll -using CEnum +using CEnum: CEnum, @cenum const NULL = C_NULL const INT_MIN = typemin(Cint) -import CUDA: @checked, CuPtr, CUstream +import CUDA: CuPtr, CUstream, @checked function check(f) res = f()::ncclResult_t if res != ncclSuccess - throw(NCCLError(res)) + throw(NCCLError(err)) end return end @@ -28,10 +28,20 @@ struct ncclConfig_v21700 maxCTAs::Cint netName::Cstring splitShare::Cint + trafficClass::Cint end const ncclConfig_t = ncclConfig_v21700 +struct ncclSimInfo_v22200 + size::Cint + magic::Cuint + version::Cuint + estimatedTime::Cfloat +end + +const ncclSimInfo_t = ncclSimInfo_v22200 + mutable struct ncclComm end const ncclComm_t = Ptr{ncclComm} @@ -152,6 +162,20 @@ end config::Ptr{ncclConfig_t})::ncclResult_t end +@checked function ncclCommInitRankScalable(newcomm, nranks, myrank, nId, commIds, config) + @ccall libnccl.ncclCommInitRankScalable(newcomm::Ptr{ncclComm_t}, nranks::Cint, + myrank::Cint, nId::Cint, + commIds::Ptr{ncclUniqueId}, + config::Ptr{ncclConfig_t})::ncclResult_t +end + +@checked function pncclCommInitRankScalable(newcomm, nranks, myrank, nId, commIds, config) + @ccall libnccl.pncclCommInitRankScalable(newcomm::Ptr{ncclComm_t}, nranks::Cint, + myrank::Cint, nId::Cint, + commIds::Ptr{ncclUniqueId}, + config::Ptr{ncclConfig_t})::ncclResult_t +end + function ncclGetErrorString(result) @ccall libnccl.ncclGetErrorString(result::ncclResult_t)::Cstring end @@ -168,6 +192,16 @@ function pncclGetLastError(comm) @ccall libnccl.pncclGetLastError(comm::ncclComm_t)::Cstring end +# no prototype is found for this function at nccl.h:192:7, please use with caution +function ncclResetDebugInit() + @ccall libnccl.ncclResetDebugInit()::Cvoid +end + +# no prototype is found for this function at nccl.h:193:6, please use with caution +function pncclResetDebugInit() + @ccall libnccl.pncclResetDebugInit()::Cvoid +end + @checked function ncclCommGetAsyncError(comm, asyncError) @ccall libnccl.ncclCommGetAsyncError(comm::ncclComm_t, asyncError::Ptr{ncclResult_t})::ncclResult_t @@ -202,6 +236,24 @@ end @ccall libnccl.pncclCommUserRank(comm::ncclComm_t, rank::Ptr{Cint})::ncclResult_t end +@checked function ncclCommRegister(comm, buff, size, handle) + @ccall libnccl.ncclCommRegister(comm::ncclComm_t, buff::CuPtr{Cvoid}, size::Cint, + handle::Ptr{Ptr{Cvoid}})::ncclResult_t +end + +@checked function pncclCommRegister(comm, buff, size, handle) + @ccall libnccl.pncclCommRegister(comm::ncclComm_t, buff::CuPtr{Cvoid}, size::Cint, + handle::Ptr{Ptr{Cvoid}})::ncclResult_t +end + +@checked function ncclCommDeregister(comm, handle) + @ccall libnccl.ncclCommDeregister(comm::ncclComm_t, handle::CuPtr{Cvoid})::ncclResult_t +end + +@checked function pncclCommDeregister(comm, handle) + @ccall libnccl.pncclCommDeregister(comm::ncclComm_t, handle::CuPtr{Cvoid})::ncclResult_t +end + @cenum ncclRedOp_dummy_t::UInt32 begin ncclNumOps_dummy = 5 end @@ -231,7 +283,10 @@ end ncclFloat = 7 ncclFloat64 = 8 ncclDouble = 8 - ncclNumTypes = 9 + ncclBfloat16 = 9 + ncclFloat8e4m3 = 10 + ncclFloat8e5m2 = 11 + ncclNumTypes = 12 end @cenum ncclScalarResidence_t::UInt32 begin @@ -355,53 +410,43 @@ end peer::Cint, comm::ncclComm_t, stream::CUstream)::ncclResult_t end -# no prototype is found for this function at nccl.h:416:15, please use with caution +# no prototype is found for this function at nccl.h:454:15, please use with caution @checked function ncclGroupStart() @ccall libnccl.ncclGroupStart()::ncclResult_t end -# no prototype is found for this function at nccl.h:417:14, please use with caution +# no prototype is found for this function at nccl.h:455:14, please use with caution @checked function pncclGroupStart() @ccall libnccl.pncclGroupStart()::ncclResult_t end -# no prototype is found for this function at nccl.h:426:15, please use with caution +# no prototype is found for this function at nccl.h:464:15, please use with caution @checked function ncclGroupEnd() @ccall libnccl.ncclGroupEnd()::ncclResult_t end -# no prototype is found for this function at nccl.h:427:14, please use with caution +# no prototype is found for this function at nccl.h:465:14, please use with caution @checked function pncclGroupEnd() @ccall libnccl.pncclGroupEnd()::ncclResult_t end -@checked function ncclCommRegister(comm, buff, size, handle) - @ccall libnccl.ncclCommRegister(comm::ncclComm_t, buff::CuPtr{Cvoid}, size::Cint, - handle::Ptr{Ptr{Cvoid}})::ncclResult_t -end - -@checked function pncclCommRegister(comm, buff, size, handle) - @ccall libnccl.pncclCommRegister(comm::ncclComm_t, buff::CuPtr{Cvoid}, size::Cint, - handle::Ptr{Ptr{Cvoid}})::ncclResult_t -end - -@checked function ncclCommDeregister(comm, handle) - @ccall libnccl.ncclCommDeregister(comm::ncclComm_t, handle::CuPtr{Cvoid})::ncclResult_t +@checked function ncclGroupSimulateEnd(simInfo) + @ccall libnccl.ncclGroupSimulateEnd(simInfo::Ptr{ncclSimInfo_t})::ncclResult_t end -@checked function pncclCommDeregister(comm, handle) - @ccall libnccl.pncclCommDeregister(comm::ncclComm_t, handle::CuPtr{Cvoid})::ncclResult_t +@checked function pncclGroupSimulateEnd(simInfo) + @ccall libnccl.pncclGroupSimulateEnd(simInfo::Ptr{ncclSimInfo_t})::ncclResult_t end const NCCL_MAJOR = 2 -const NCCL_MINOR = 19 +const NCCL_MINOR = 26 -const NCCL_PATCH = 4 +const NCCL_PATCH = 5 const NCCL_SUFFIX = "" -const NCCL_VERSION_CODE = 21904 +const NCCL_VERSION_CODE = 22605 const NCCL_COMM_NULL = NULL @@ -413,7 +458,12 @@ const NCCL_CONFIG_UNDEF_PTR = NULL const NCCL_SPLIT_NOCOLOR = -1 -# Skipping MacroDefinition: NCCL_CONFIG_INITIALIZER { sizeof ( ncclConfig_t ) , /* size */ 0xcafebeef , /* magic */ NCCL_VERSION ( NCCL_MAJOR , NCCL_MINOR , NCCL_PATCH ) , /* version */ NCCL_CONFIG_UNDEF_INT , /* blocking */ NCCL_CONFIG_UNDEF_INT , /* cgaClusterSize */ NCCL_CONFIG_UNDEF_INT , /* minCTAs */ NCCL_CONFIG_UNDEF_INT , /* maxCTAs */ NCCL_CONFIG_UNDEF_PTR , /* netName */ NCCL_CONFIG_UNDEF_INT /* splitShare */ \ +const NCCL_UNDEF_FLOAT = -(Float32(1.0)) + +# Skipping MacroDefinition: NCCL_CONFIG_INITIALIZER { sizeof ( ncclConfig_t ) , /* size */ 0xcafebeef , /* magic */ NCCL_VERSION ( NCCL_MAJOR , NCCL_MINOR , NCCL_PATCH ) , /* version */ NCCL_CONFIG_UNDEF_INT , /* blocking */ NCCL_CONFIG_UNDEF_INT , /* cgaClusterSize */ NCCL_CONFIG_UNDEF_INT , /* minCTAs */ NCCL_CONFIG_UNDEF_INT , /* maxCTAs */ NCCL_CONFIG_UNDEF_PTR , /* netName */ NCCL_CONFIG_UNDEF_INT , /* splitShare */ NCCL_CONFIG_UNDEF_INT , /* trafficClass */ \ +#} + +# Skipping MacroDefinition: NCCL_SIM_INFO_INITIALIZER { sizeof ( ncclSimInfo_t ) , /* size */ 0x74685283 , /* magic */ NCCL_VERSION ( NCCL_MAJOR , NCCL_MINOR , NCCL_PATCH ) , /* version */ NCCL_UNDEF_FLOAT /* estimated time */ \ #} export NCCLError diff --git a/src/pointtopoint.jl b/src/pointtopoint.jl index c0d9544..b712dee 100644 --- a/src/pointtopoint.jl +++ b/src/pointtopoint.jl @@ -13,9 +13,9 @@ called. """ function Send(sendbuf, comm::Communicator; dest::Integer, stream::CuStream=default_device_stream(comm)) - count = length(sendbuf) + a_count = count(sendbuf) datatype = ncclDataType_t(eltype(sendbuf)) - ncclSend(sendbuf, count, datatype, dest, comm, stream) + ncclSend(sendbuf, a_count, datatype, dest, comm, stream) return nothing end @@ -34,8 +34,8 @@ Write the data from a matching [`Send`](@ref) on rank `source` into `recvbuf`. """ function Recv!(recvbuf, comm::Communicator; source::Integer, stream::CuStream=default_device_stream(comm)) - count = length(recvbuf) + a_count = count(recvbuf) datatype = ncclDataType_t(eltype(recvbuf)) - ncclRecv(recvbuf, count, datatype, source, comm, stream) + ncclRecv(recvbuf, a_count, datatype, source, comm, stream) return recvbuf.data end diff --git a/test/runtests.jl b/test/runtests.jl index df695c0..1dc2dfb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,171 +26,187 @@ end devs = CUDA.devices() comms = NCCL.Communicators(devs) - @testset "sum" begin - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 + @testset "$T" for T in (Float64, ComplexF64) + @testset "sum" begin + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + end + end + answer = sum(1:length(devs)) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) + end + end + + @testset "NCCL.avg" begin + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) + end + end + answer = sum(1:length(devs)) / length(devs) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .≈ answer) + end + end + end +end + +@testset "Broadcast!" begin + devs = CUDA.devices() + comms = NCCL.Communicators(devs) + + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + root = 0 for (ii, dev) in enumerate(devs) CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) + sendbuf[ii] = (ii - 1) == root ? CuArray(fill(T(1.0), 512)) : CUDA.zeros(T, 512) + recvbuf[ii] = CUDA.zeros(T, 512) end NCCL.group() do for ii in 1:length(devs) - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + NCCL.Broadcast!(sendbuf[ii], recvbuf[ii], comms[ii]; root) end end - answer = sum(1:length(devs)) + answer = 1.0 for (ii, dev) in enumerate(devs) device!(ii - 1) crecv = collect(recvbuf[ii]) @test all(crecv .== answer) end end +end - @testset "NCCL.avg" begin - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 +@testset "Reduce!" begin + devs = CUDA.devices() + comms = NCCL.Communicators(devs) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + root = 0 for (ii, dev) in enumerate(devs) CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) + sendbuf[ii] = CuArray(fill(T(ii), 512)) + recvbuf[ii] = CUDA.zeros(T, 512) end NCCL.group() do for ii in 1:length(devs) - NCCL.Allreduce!(sendbuf[ii], recvbuf[ii], NCCL.avg, comms[ii]) + NCCL.Reduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]; root) end end - answer = sum(1:length(devs)) / length(devs) for (ii, dev) in enumerate(devs) + answer = (ii - 1) == root ? sum(1:length(devs)) : 0.0 device!(ii - 1) crecv = collect(recvbuf[ii]) - @test all(crecv .≈ answer) + @test all(crecv .== answer) end end end -@testset "Broadcast!" begin +@testset "Allgather!" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - root = 0 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = (ii - 1) == root ? CuArray(fill(Float64(1.0), 512)) : CUDA.zeros(Float64, 512) - recvbuf[ii] = CUDA.zeros(Float64, 512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Broadcast!(sendbuf[ii], recvbuf[ii], comms[ii]; root) - end - end - answer = 1.0 - for (ii, dev) in enumerate(devs) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) - end -end -@testset "Reduce!" begin - devs = CUDA.devices() - comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - root = 0 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), 512)) - recvbuf[ii] = CUDA.zeros(Float64, 512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Reduce!(sendbuf[ii], recvbuf[ii], +, comms[ii]; root) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), 512)) + recvbuf[ii] = CUDA.zeros(T, length(devs)*512) end - end - for (ii, dev) in enumerate(devs) - answer = (ii - 1) == root ? sum(1:length(devs)) : 0.0 - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) - end -end - -@testset "Allgather!" begin - devs = CUDA.devices() - comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), 512)) - recvbuf[ii] = CUDA.zeros(Float64, length(devs)*512) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.Allgather!(sendbuf[ii], recvbuf[ii], comms[ii]) + NCCL.group() do + for ii in 1:length(devs) + NCCL.Allgather!(sendbuf[ii], recvbuf[ii], comms[ii]) + end + end + answer = vec(repeat(1:length(devs), inner=512)) + for (ii, dev) in enumerate(devs) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - answer = vec(repeat(1:length(devs), inner=512)) - for (ii, dev) in enumerate(devs) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end @testset "ReduceScatter!" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(vec(repeat(collect(1:length(devs)), inner=2))) - recvbuf[ii] = CUDA.zeros(Float64, 2) - end - NCCL.group() do - for ii in 1:length(devs) - NCCL.ReduceScatter!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(vec(repeat(collect(1:length(devs)), inner=2))) + recvbuf[ii] = CUDA.zeros(T, 2) + end + NCCL.group() do + for ii in 1:length(devs) + NCCL.ReduceScatter!(sendbuf[ii], recvbuf[ii], +, comms[ii]) + end + end + for (ii, dev) in enumerate(devs) + answer = length(devs)*ii + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - for (ii, dev) in enumerate(devs) - answer = length(devs)*ii - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end @testset "Send/Recv" begin devs = CUDA.devices() comms = NCCL.Communicators(devs) - recvbuf = Vector{CuVector{Float64}}(undef, length(devs)) - sendbuf = Vector{CuVector{Float64}}(undef, length(devs)) - N = 512 - for (ii, dev) in enumerate(devs) - CUDA.device!(ii - 1) - sendbuf[ii] = CuArray(fill(Float64(ii), N)) - recvbuf[ii] = CUDA.zeros(Float64, N) - end - NCCL.group() do - for ii in 1:length(devs) - comm = comms[ii] - dest = mod(NCCL.rank(comm)+1, NCCL.size(comm)) - source = mod(NCCL.rank(comm)-1, NCCL.size(comm)) - NCCL.Send(sendbuf[ii], comm; dest) - NCCL.Recv!(recvbuf[ii], comm; source) + @testset "$T" for T in (Float64, ComplexF64) + recvbuf = Vector{CuVector{T}}(undef, length(devs)) + sendbuf = Vector{CuVector{T}}(undef, length(devs)) + N = 512 + for (ii, dev) in enumerate(devs) + CUDA.device!(ii - 1) + sendbuf[ii] = CuArray(fill(T(ii), N)) + recvbuf[ii] = CUDA.zeros(T, N) + end + + NCCL.group() do + for ii in 1:length(devs) + comm = comms[ii] + dest = mod(NCCL.rank(comm)+1, NCCL.size(comm)) + source = mod(NCCL.rank(comm)-1, NCCL.size(comm)) + NCCL.Send(sendbuf[ii], comm; dest) + NCCL.Recv!(recvbuf[ii], comm; source) + end + end + for (ii, dev) in enumerate(devs) + answer = mod1(ii - 1, length(devs)) + device!(ii - 1) + crecv = collect(recvbuf[ii]) + @test all(crecv .== answer) end - end - for (ii, dev) in enumerate(devs) - answer = mod1(ii - 1, length(devs)) - device!(ii - 1) - crecv = collect(recvbuf[ii]) - @test all(crecv .== answer) end end