From c5f72cd22ad56e27de58258b3e4844cb58378c49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 15 Feb 2025 19:52:06 +0100 Subject: [PATCH 01/97] Register MPI symbols on load --- ext/ReactantMPIExt.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/ext/ReactantMPIExt.jl b/ext/ReactantMPIExt.jl index 5ede919c0d..ced835e4ed 100644 --- a/ext/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt.jl @@ -1,5 +1,6 @@ module ReactantMPIExt +using Reactant using Reactant: Reactant, Distributed using MPI: MPI @@ -33,4 +34,22 @@ function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) return Int(MPI.Comm_rank(new_comm)) end +function __init__() + for name in ( + "MPI_Init", + "MPI_Finalize", + "MPI_Comm_rank", + "MPI_Comm_size", + "MPI_Send", + "MPI_Recv", + "MPI_Isend", + "MPI_Irecv", + "MPI_Wait", + "MPI_Request_free", + ) + sym = Libdl.dlsym(MPI.API.libmpi, name) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid + end +end + end From d5eaa2ddd2cd11e48d82db06b7842dd268ea792c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 16 Feb 2025 10:30:18 +0100 Subject: [PATCH 02/97] ops --- ext/ReactantMPIExt/Ops.jl | 80 +++++++++++++++ ext/ReactantMPIExt/Overrides.jl | 110 +++++++++++++++++++++ ext/{ => ReactantMPIExt}/ReactantMPIExt.jl | 8 ++ 3 files changed, 198 insertions(+) create mode 100644 ext/ReactantMPIExt/Ops.jl create mode 100644 ext/ReactantMPIExt/Overrides.jl rename ext/{ => ReactantMPIExt}/ReactantMPIExt.jl (90%) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl new file mode 100644 index 0000000000..eac8e27796 --- /dev/null +++ b/ext/ReactantMPIExt/Ops.jl @@ -0,0 +1,80 @@ +module Ops +using Reactant: TracedRArray, TracedRNumber +using Reactant: MLIR +using Reactant.MLIR.Dialects: mpi +using ..ReactantMPIExt: TracedRequest + +# TODO add communicators + +function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__)) + return mpi.init(; location) +end + +function finalize(; location=mlir_stacktrace("mpi.finalize", @__FILE__, @__LINE__)) + return mpi.finalize(; location) +end + +function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) + res = MLIR.IR.result(mpi.comm_rank(; location)) + return TracedRNumber{Int}((), res) +end + +function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) + res = MLIR.IR.result(mpi.comm_size(; location)) + return TracedRNumber{Int}((), res) +end + +# TODO should we emit `stablehlo.optimization_barrier` here too? +function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__)) + return mpi.barrier(; location) +end + +function send( + buf::TracedRArray, + tag::TracedRNumber, + dest::TracedRNumber; + location=mlir_stacktrace("mpi.send", @__FILE__, @__LINE__), +) + return mpi.send(buf.mlir_data, tag.mlir_data, dest.mlir_data; location) +end + +# TODO need c-function for creating MLIR `mpi.request` type? 
+function isend( + buf::TracedRArray, + tag::TracedRNumber, + dest::TracedRNumber; + location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), +) + return TracedRequest( + MLIR.IR.result(mpi.isend(buf.mlir_data, tag.mlir_data, dest.mlir_data; location)) + ) +end + +function recv!( + ref::TracedRArray, + tag::TracedRNumber, + src::TracedRNumber; + location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__), +) + return mpi.recv(ref.mlir_data, tag.mlir_data, src.mlir_data; location) +end + +# TODO need c-function for creating MLIR `mpi.request` type? +function irecv!( + ref::TracedRArray, + tag::TracedRNumber, + src::TracedRNumber; + location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), +) + return TracedRequest( + MLIR.IR.result(mpi.irecv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)) + ) +end + +function wait( + req::TracedRequest; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__) +) + return mpi.wait(req.mlir_data; location) +end + +end # module diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl new file mode 100644 index 0000000000..984c7355c0 --- /dev/null +++ b/ext/ReactantMPIExt/Overrides.jl @@ -0,0 +1,110 @@ +@reactant_overlay @noinline function MPI.Init(; kwargs...) + if !isempty(kwargs) + @warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs... + end + return Ops.init() +end + +@reactant_overlay @noinline function MPI.Init(; kwargs...) + return Ops.finalize() +end + +@reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.comm_rank() +end + +@reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.comm_size() +end + +@reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.barrier() +end + +# TODO status not supported yet +function MPI.Wait(req::TracedRequest) + return Ops.wait(req) +end + +# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` +function MPI.Send(buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + + tag = if !(tag isa TracedRNumber) + Ops.constant(tag) + end + + dest = if !(dest isa TracedRNumber) + Ops.constant(dest) + end + + return Ops.send(buf, tag, dest) +end + +# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` +function MPI.Isend(buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + + tag = if !(tag isa TracedRNumber) + Ops.constant(tag) + end + + return dest = if !(dest isa TracedRNumber) + Ops.constant(dest) + end + + return Ops.isend(buf, tag, dest) +end + +# TODO should we error if other `AbstractRequest` types are passed in? +function MPI.Isend( + buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm, req::TracedRequest +) + gen_req = MPI.Isend(buf, dest, tag, comm) + req.mlir_data = gen_req.mlir_data + return req +end + +# TODO use `make_tracer` to delinearize arbitrary types? 
check out `MPI.Buffer` +function MPI.Recv!( + recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, status +) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + @assert isnothing(status) "Status not supported yet" + + tag = if !(tag isa TracedRNumber) + Ops.constant(tag) + end + + source = if !(source isa TracedRNumber) + Ops.constant(source) + end + + return Ops.recv(recvbuf, tag, source) +end + +# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` +function MPI.IRecv!(recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + + tag = if !(tag isa TracedRNumber) + Ops.constant(tag) + end + + source = if !(source isa TracedRNumber) + Ops.constant(source) + end + + return Ops.irecv!(recvbuf, tag, source) +end + +function MPI.IRecv!( + recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, req::TracedRequest +) + gen_req = MPI.IRecv!(recvbuf, source, tag, comm) + req.mlir_data = gen_req.mlir_data + return req +end diff --git a/ext/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl similarity index 90% rename from ext/ReactantMPIExt.jl rename to ext/ReactantMPIExt/ReactantMPIExt.jl index ced835e4ed..34ff037d7a 100644 --- a/ext/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -44,6 +44,7 @@ function __init__() "MPI_Recv", "MPI_Isend", "MPI_Irecv", + "MPI_Barrier", "MPI_Wait", "MPI_Request_free", ) @@ -52,4 +53,11 @@ function __init__() end end +struct TracedRequest <: MPI.AbstractRequest + mlir_data::Union{Nothing,MLIR.IR.Value} end + +include("Ops.jl") +include("Overrides.jl") + +end # module From 5f800fef98ea1c976bcf579809f2df6025e21958 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sun, 16 Feb 2025 12:12:43 +0100 Subject: [PATCH 03/97] Update ext/ReactantMPIExt/Overrides.jl Co-authored-by: Paul Berg --- ext/ReactantMPIExt/Overrides.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 984c7355c0..d09cdfaefd 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -5,7 +5,7 @@ return Ops.init() end -@reactant_overlay @noinline function MPI.Init(; kwargs...) +@reactant_overlay @noinline function MPI.Finalize(; kwargs...) 
return Ops.finalize() end From 5e4a8cd7b2b5b450d1f0da6e5faaa7126954860e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sat, 22 Feb 2025 19:20:27 +0100 Subject: [PATCH 04/97] Update ext/ReactantMPIExt/Overrides.jl --- ext/ReactantMPIExt/Overrides.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index d09cdfaefd..627c0c1940 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -52,7 +52,7 @@ function MPI.Isend(buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm) Ops.constant(tag) end - return dest = if !(dest isa TracedRNumber) + dest = if !(dest isa TracedRNumber) Ops.constant(dest) end From 215cb1dc4dbb9d839871b2d14c4237d7a51e36f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 28 Feb 2025 12:06:46 +0100 Subject: [PATCH 05/97] register MPI constants --- ext/ReactantMPIExt/ReactantMPIExt.jl | 161 ++++++++++++++++++++++++++- 1 file changed, 159 insertions(+), 2 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 34ff037d7a..2046beea59 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -35,7 +35,8 @@ function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) end function __init__() - for name in ( + # register MPI routines + for name in [ "MPI_Init", "MPI_Finalize", "MPI_Comm_rank", @@ -47,10 +48,166 @@ function __init__() "MPI_Barrier", "MPI_Wait", "MPI_Request_free", - ) + ] sym = Libdl.dlsym(MPI.API.libmpi, name) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid end + + # register MPI constants + # NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, they are represented as word-size values (i.e. 
`int` or ptr) + for name in [ + # communicators + :MPI_COMM_WORLD, + :MPI_COMM_SELF, + :MPI_COMM_NULL, + :MPI_COMM_TYPE_SHARED, + # datatypes + :MPI_DATATYPE_NULL, + :MPI_BYTE, + :MPI_PACKED, + :MPI_CHAR, + :MPI_SHORT, + :MPI_INT, + :MPI_LONG, + :MPI_FLOAT, + :MPI_DOUBLE, + :MPI_UNSIGNED_CHAR, + :MPI_SIGNED_CHAR, + :MPI_UNSIGNED_SHORT, + :MPI_UNSIGNED_LONG, + :MPI_UNSIGNED, + :MPI_FLOAT_INT, + :MPI_DOUBLE_INT, + :MPI_LONG_DOUBLE_INT, + :MPI_LONG_INT, + :MPI_SHORT_INT, + :MPI_2INT, + :MPI_UB, + :MPI_LB, + :MPI_WCHAR, + :MPI_LONG_LONG_INT, + :MPI_UNSIGNED_LONG_LONG, + :MPI_2COMPLEX, + :MPI_2DOUBLE_COMPLEX, + :MPI_INT8_T, + :MPI_UINT8_T, + :MPI_INT16_T, + :MPI_UINT16_T, + :MPI_INT32_T, + :MPI_UINT32_T, + :MPI_INT64_T, + :MPI_UINT64_T, + :MPI_AINT, + :MPI_OFFSET, + :MPI_C_BOOL, + :MPI_C_FLOAT_COMPLEX, + :MPI_C_DOUBLE_COMPLEX, + :MPI_C_LONG_DOUBLE_COMPLEX, + :MPI_COUNT, + # ops + :MPI_OP_NULL, + :MPI_MAX, + :MPI_MIN, + :MPI_SUM, + :MPI_PROD, + :MPI_LAND, + :MPI_BAND, + :MPI_LOR, + :MPI_BOR, + :MPI_LXOR, + :MPI_BXOR, + :MPI_MINLOC, + :MPI_MAXLOC, + :MPI_REPLACE, + :MPI_NO_OP, + # request + :MPI_REQUEST_NULL, + # status + :MPI_STATUS_IGNORE, + :MPI_STATUSES_IGNORE, + # error + :MPI_SUCCESS, + :MPI_ERR_BUFFER, + :MPI_ERR_COUNT, + :MPI_ERR_TYPE, + :MPI_ERR_TAG, + :MPI_ERR_COMM, + :MPI_ERR_RANK, + :MPI_ERR_REQUEST, + :MPI_ERR_ROOT, + :MPI_ERR_GROUP, + :MPI_ERR_OP, + :MPI_ERR_TOPOLOGY, + :MPI_ERR_DIMS, + :MPI_ERR_ARG, + :MPI_ERR_UNKNOWN, + :MPI_ERR_TRUNCATE, + :MPI_ERR_OTHER, + :MPI_ERR_INTERN, + :MPI_ERR_IN_STATUS, + :MPI_ERR_PENDING, + :MPI_ERR_ACCESS, + :MPI_ERR_AMODE, + :MPI_ERR_ASSERT, + :MPI_ERR_BAD_FILE, + :MPI_ERR_BASE, + :MPI_ERR_CONVERSION, + :MPI_ERR_DISP, + :MPI_ERR_DUP_DATAREP, + :MPI_ERR_FILE_EXISTS, + :MPI_ERR_FILE_IN_USE, + :MPI_ERR_FILE, + :MPI_ERR_INFO_KEY, + :MPI_ERR_INFO_NOKEY, + :MPI_ERR_INFO_VALUE, + :MPI_ERR_INFO, + :MPI_ERR_IO, + :MPI_ERR_KEYVAL, + :MPI_ERR_LOCKTYPE, + :MPI_ERR_NAME, + :MPI_ERR_NO_MEM, + :MPI_ERR_NOT_SAME, + :MPI_ERR_NO_SPACE, + :MPI_ERR_NO_SUCH_FILE, + :MPI_ERR_PORT, + :MPI_ERR_QUOTA, + :MPI_ERR_READ_ONLY, + :MPI_ERR_RMA_CONFLICT, + :MPI_ERR_RMA_SYNC, + :MPI_ERR_SERVICE, + :MPI_ERR_SIZE, + :MPI_ERR_SPAWN, + :MPI_ERR_UNSUPPORTED_DATAREP, + :MPI_ERR_UNSUPPORTED_OPERATION, + :MPI_ERR_WIN, + :MPI_T_ERR_MEMORY, + :MPI_T_ERR_NOT_INITIALIZED, + :MPI_T_ERR_CANNOT_INIT, + :MPI_T_ERR_INVALID_INDEX, + :MPI_T_ERR_INVALID_ITEM, + :MPI_T_ERR_INVALID_HANDLE, + :MPI_T_ERR_OUT_OF_HANDLES, + :MPI_T_ERR_OUT_OF_SESSIONS, + :MPI_T_ERR_INVALID_SESSION, + :MPI_T_ERR_CVAR_SET_NOT_NOW, + :MPI_T_ERR_CVAR_SET_NEVER, + :MPI_T_ERR_PVAR_NO_STARTSTOP, + :MPI_T_ERR_PVAR_NO_WRITE, + :MPI_T_ERR_PVAR_NO_ATOMIC, + :MPI_ERR_RMA_RANGE, + :MPI_ERR_RMA_ATTACH, + :MPI_ERR_RMA_FLAVOR, + :MPI_ERR_RMA_SHARED, + :MPI_T_ERR_INVALID, + :MPI_T_ERR_INVALID_NAME, + :MPI_ERR_PROC_ABORTED, + :MPI_ERR_PROC_FAILED, + :MPI_ERR_PROC_FAILED_PENDING, + :MPI_ERR_REVOKED, + ] + value = getproperty(MPI.API, name) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Cint)::Cvoid + end end struct TracedRequest <: MPI.AbstractRequest From 4b54755ae46a1062e263225a91d442d39187d242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 3 Mar 2025 23:14:29 +0100 Subject: [PATCH 06/97] Fix MPI specializations --- ext/ReactantMPIExt/Overrides.jl | 47 +++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 627c0c1940..c6845f60b0 100644 --- 
a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -1,3 +1,5 @@ +using Reactant: @reactant_overlay, TracedRArray, TracedRNumber + @reactant_overlay @noinline function MPI.Init(; kwargs...) if !isempty(kwargs) @warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs... @@ -30,30 +32,35 @@ function MPI.Wait(req::TracedRequest) end # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` -function MPI.Send(buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm) - @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - - tag = if !(tag isa TracedRNumber) - Ops.constant(tag) - end - - dest = if !(dest isa TracedRNumber) - Ops.constant(dest) - end +function MPI.Send(buf::TracedRArray, dest::Integer, tag::Integer, comm::MPI.Comm) + tag = Reactant.Ops.constant(tag) + dest = Reactant.Ops.constant(dest) + return MPI.Send(buf, dest, tag, comm) +end +# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` +function MPI.Send( + buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm +) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" return Ops.send(buf, tag, dest) end # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` -function MPI.Isend(buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm) +function MPI.Isend( + buf::TracedRArray, + dest::Union{T,TracedRNumber{T}}, + tag::Union{T,TracedRNumber{T}}, + comm::MPI.Comm, +) where {T<:Integer} @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" tag = if !(tag isa TracedRNumber) - Ops.constant(tag) + Reactant.Ops.constant(tag) end dest = if !(dest isa TracedRNumber) - Ops.constant(dest) + Reactant.Ops.constant(dest) end return Ops.isend(buf, tag, dest) @@ -76,35 +83,35 @@ function MPI.Recv!( @assert isnothing(status) "Status not supported yet" tag = if !(tag isa TracedRNumber) - Ops.constant(tag) + Reactant.Ops.constant(tag) end source = if !(source isa TracedRNumber) - Ops.constant(source) + Reactant.Ops.constant(source) end return Ops.recv(recvbuf, tag, source) end # TODO use `make_tracer` to delinearize arbitrary types? 
check out `MPI.Buffer` -function MPI.IRecv!(recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm) +function MPI.Irecv!(recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" tag = if !(tag isa TracedRNumber) - Ops.constant(tag) + Reactant.Ops.constant(tag) end source = if !(source isa TracedRNumber) - Ops.constant(source) + Reactant.Ops.constant(source) end return Ops.irecv!(recvbuf, tag, source) end -function MPI.IRecv!( +function MPI.Irecv!( recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, req::TracedRequest ) - gen_req = MPI.IRecv!(recvbuf, source, tag, comm) + gen_req = MPI.Irecv!(recvbuf, source, tag, comm) req.mlir_data = gen_req.mlir_data return req end From e7fd20ed3dcd061afaee8869f6c7e5934e0f471d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 3 Mar 2025 23:15:04 +0100 Subject: [PATCH 07/97] fix some symbol registration --- ext/ReactantMPIExt/ReactantMPIExt.jl | 83 +++++++++++++++------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 2046beea59..bec689e9e3 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -1,8 +1,9 @@ module ReactantMPIExt using Reactant -using Reactant: Reactant, Distributed +using Reactant: Reactant, Distributed, MLIR using MPI: MPI +using Libdl # https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized() @@ -35,21 +36,24 @@ function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) end function __init__() + # TODO maybe it's more efficient if we use `RTLD_NOW` instead of `RTLD_LAZY`? 
+ libmpi_handle = Libdl.dlopen(MPI.API.libmpi, RTLD_LAZY | RTLD_GLOBAL) + # register MPI routines for name in [ - "MPI_Init", - "MPI_Finalize", - "MPI_Comm_rank", - "MPI_Comm_size", - "MPI_Send", - "MPI_Recv", - "MPI_Isend", - "MPI_Irecv", - "MPI_Barrier", - "MPI_Wait", - "MPI_Request_free", + :MPI_Init, + :MPI_Finalize, + :MPI_Comm_rank, + :MPI_Comm_size, + :MPI_Send, + :MPI_Recv, + :MPI_Isend, + :MPI_Irecv, + :MPI_Barrier, + :MPI_Wait, + :MPI_Request_free, ] - sym = Libdl.dlsym(MPI.API.libmpi, name) + sym = Libdl.dlsym(libmpi_handle, name) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid end @@ -81,14 +85,14 @@ function __init__() :MPI_LONG_DOUBLE_INT, :MPI_LONG_INT, :MPI_SHORT_INT, - :MPI_2INT, + # :MPI_2INT, :MPI_UB, :MPI_LB, :MPI_WCHAR, :MPI_LONG_LONG_INT, :MPI_UNSIGNED_LONG_LONG, - :MPI_2COMPLEX, - :MPI_2DOUBLE_COMPLEX, + # :MPI_2COMPLEX, + # :MPI_2DOUBLE_COMPLEX, :MPI_INT8_T, :MPI_UINT8_T, :MPI_INT16_T, @@ -102,7 +106,7 @@ function __init__() :MPI_C_BOOL, :MPI_C_FLOAT_COMPLEX, :MPI_C_DOUBLE_COMPLEX, - :MPI_C_LONG_DOUBLE_COMPLEX, + # :MPI_C_LONG_DOUBLE_COMPLEX, :MPI_COUNT, # ops :MPI_OP_NULL, @@ -180,38 +184,41 @@ function __init__() :MPI_ERR_UNSUPPORTED_DATAREP, :MPI_ERR_UNSUPPORTED_OPERATION, :MPI_ERR_WIN, - :MPI_T_ERR_MEMORY, - :MPI_T_ERR_NOT_INITIALIZED, - :MPI_T_ERR_CANNOT_INIT, - :MPI_T_ERR_INVALID_INDEX, - :MPI_T_ERR_INVALID_ITEM, - :MPI_T_ERR_INVALID_HANDLE, - :MPI_T_ERR_OUT_OF_HANDLES, - :MPI_T_ERR_OUT_OF_SESSIONS, - :MPI_T_ERR_INVALID_SESSION, - :MPI_T_ERR_CVAR_SET_NOT_NOW, - :MPI_T_ERR_CVAR_SET_NEVER, - :MPI_T_ERR_PVAR_NO_STARTSTOP, - :MPI_T_ERR_PVAR_NO_WRITE, - :MPI_T_ERR_PVAR_NO_ATOMIC, + # :MPI_T_ERR_MEMORY, + # :MPI_T_ERR_NOT_INITIALIZED, + # :MPI_T_ERR_CANNOT_INIT, + # :MPI_T_ERR_INVALID_INDEX, + # :MPI_T_ERR_INVALID_ITEM, + # :MPI_T_ERR_INVALID_HANDLE, + # :MPI_T_ERR_OUT_OF_HANDLES, + # :MPI_T_ERR_OUT_OF_SESSIONS, + # :MPI_T_ERR_INVALID_SESSION, + # :MPI_T_ERR_CVAR_SET_NOT_NOW, + # :MPI_T_ERR_CVAR_SET_NEVER, + # :MPI_T_ERR_PVAR_NO_STARTSTOP, + # :MPI_T_ERR_PVAR_NO_WRITE, + # :MPI_T_ERR_PVAR_NO_ATOMIC, :MPI_ERR_RMA_RANGE, :MPI_ERR_RMA_ATTACH, :MPI_ERR_RMA_FLAVOR, :MPI_ERR_RMA_SHARED, - :MPI_T_ERR_INVALID, - :MPI_T_ERR_INVALID_NAME, - :MPI_ERR_PROC_ABORTED, - :MPI_ERR_PROC_FAILED, - :MPI_ERR_PROC_FAILED_PENDING, - :MPI_ERR_REVOKED, + # :MPI_T_ERR_INVALID, + # :MPI_T_ERR_INVALID_NAME, + # :MPI_ERR_PROC_ABORTED, + # :MPI_ERR_PROC_FAILED, + # :MPI_ERR_PROC_FAILED_PENDING, + # :MPI_ERR_REVOKED, ] value = getproperty(MPI.API, name) + if value isa Base.RefValue + value = value[] + end @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Cint)::Cvoid end end struct TracedRequest <: MPI.AbstractRequest - mlir_data::Union{Nothing,MLIR.IR.Value} + mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} end include("Ops.jl") From 774982aad1fef66413903b913fcd9d9469269d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 3 Mar 2025 23:15:52 +0100 Subject: [PATCH 08/97] refactor MPI Ops --- ext/ReactantMPIExt/Ops.jl | 139 ++++++++++++++++++++++++++++++++------ 1 file changed, 119 insertions(+), 20 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index eac8e27796..412f8b55ed 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -1,41 +1,112 @@ module Ops -using Reactant: TracedRArray, TracedRNumber +using Reactant: Reactant, TracedRArray, TracedRNumber using Reactant: MLIR -using Reactant.MLIR.Dialects: mpi +using Reactant.MLIR: IR +using Reactant.MLIR.IR: 
@mlir_str +using Reactant.MLIR.Dialects: mpi, func, llvm, enzymexla +using Reactant.Ops: mlir_stacktrace using ..ReactantMPIExt: TracedRequest # TODO add communicators -function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__)) - return mpi.init(; location) -end +# function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__)) +# return mpi.init(; location) +# end -function finalize(; location=mlir_stacktrace("mpi.finalize", @__FILE__, @__LINE__)) - return mpi.finalize(; location) -end +# function finalize(; location=mlir_stacktrace("mpi.finalize", @__FILE__, @__LINE__)) +# return mpi.finalize(; location) +# end +# TODO emit wrapper if not found function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) - res = MLIR.IR.result(mpi.comm_rank(; location)) - return TracedRNumber{Int}((), res) + sym_name = "enzymexla_wrapper_MPI_Comm_rank" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + # rettype = [IR.TensorType(Int[], IR.Type(Cint))] + + current_module = IR.mmodule() + fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) + + if isnothing(fn) + # arg_type = IR.Type[IR.TensorType(Int[], IR.Type(Cint))] + arg_type = IR.Type[MLIR.IR.Type( + MLIR.API.mlirLLVMPointerTypeGet(IR.context(), Cuint(0)) + )] + function_type = IR.FunctionType(arg_type, IR.Type[]) + + @show arg_type function_type + + wrapper = IR.block!(IR.body(current_module)) do + func.func_(; sym_name, function_type, body=IR.Region()) + end + wrapper_body = IR.Block(arg_type, [IR.Location()]) + push!(IR.region(wrapper, 1), wrapper_body) + + # @show wrapper + + # fill the wrapper body + IR.block!(wrapper_body) do + # llvm.call( + # IR.Value[], + # IR.Value[]; + # callee=IR.FlatSymbolRefAttribute("MPI_Comm_rank"), + # op_bundle_sizes=MLIR.IR.Attribute(Cint[]), + # ) + # [IR.Type(Cint)], + # [IR.Type(Cint)], + # [IR.Type(Cint)], + # value = Reactant.Ops.constant(fill(Int32(1))) + # c = IR.result(llvm.mlir_constant(; res=IR.Type(Cint), value=1)) + # llvm.store(c, ...) + func.return_(IR.Value[]) + end + end + + # world = Reactant.Ops.constant(fill(0)) + value_out = Reactant.Ops.constant(fill(0)) + # inputs = IR.Value[world.mlir_data] + inputs = IR.Value[value_out.mlir_data] + + res = IR.result(enzymexla.jit_call(inputs; fn=sym_attr, result_0=IR.Type[], location)) + return TracedRNumber{Cint}((), res) end +# TODO emit wrapper if not found function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) - res = MLIR.IR.result(mpi.comm_size(; location)) - return TracedRNumber{Int}((), res) + inputs = IR.Value[] + sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Comm_size") + rettype = [IR.TensorType(Int[], IR.Type(Cint))] + + res = IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + return TracedRNumber{Cint}((), res) end +# TODO emit wrapper if not found # TODO should we emit `stablehlo.optimization_barrier` here too? function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__)) - return mpi.barrier(; location) + inputs = IR.Value[] + sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Barrier") + rettype = IR.Type[] + + # TODO should we return `TracedRNumber{Nothing}`? 
+ return IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) end +# TODO emit wrapper if not found function send( buf::TracedRArray, tag::TracedRNumber, dest::TracedRNumber; location=mlir_stacktrace("mpi.send", @__FILE__, @__LINE__), ) - return mpi.send(buf.mlir_data, tag.mlir_data, dest.mlir_data; location) + # return mpi.send(buf.mlir_data, tag.mlir_data, dest.mlir_data; location) + + # TODO emit constant for size and datatype, and pass as args + inputs = IR.Value[buf.mlir_data, tag.mlir_data, dest.mlir_data] + sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Send") + rettype = IR.Type[] + + return enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location) end # TODO need c-function for creating MLIR `mpi.request` type? @@ -45,8 +116,17 @@ function isend( dest::TracedRNumber; location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), ) + # return TracedRequest( + # IR.result(mpi.isend(buf.mlir_data, tag.mlir_data, dest.mlir_data; location)) + # ) + + # TODO emit constant for size and datatype, and pass as args + inputs = IR.Value[buf.mlir_data, tag.mlir_data, dest.mlir_data] + sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Isend") + rettype = IR.Type[] # TODO return MPI_Request -> use i32 or opaque? + return TracedRequest( - MLIR.IR.result(mpi.isend(buf.mlir_data, tag.mlir_data, dest.mlir_data; location)) + IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) ) end @@ -56,7 +136,15 @@ function recv!( src::TracedRNumber; location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__), ) - return mpi.recv(ref.mlir_data, tag.mlir_data, src.mlir_data; location) + # return mpi.recv(ref.mlir_data, tag.mlir_data, src.mlir_data; location) + + # TODO emit constant for size and datatype, and pass as args + inputs = IR.Value[ref.mlir_data, tag.mlir_data, src.mlir_data] + sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Recv") + rettype = IR.Type[] + + IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + return ref end # TODO need c-function for creating MLIR `mpi.request` type? 
@@ -66,15 +154,26 @@ function irecv!( src::TracedRNumber; location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), ) - return TracedRequest( - MLIR.IR.result(mpi.irecv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)) - ) + # return TracedRequest( + # MLIR.IR.result(mpi.irecv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)) + # ) + inputs = IR.Value[ref.mlir_data, tag.mlir_data, src.mlir_data] + sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Irecv") + rettype = IR.Type[] + + IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + return ref end function wait( req::TracedRequest; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__) ) - return mpi.wait(req.mlir_data; location) + # return mpi.wait(req.mlir_data; location) + inputs = IR.Value[req.mlir_data] + sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Wait") + rettype = IR.Type[] + + return IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) end end # module From d1fe99b7f136b64035443231c48641b00794c894 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Mar 2025 10:12:55 -0600 Subject: [PATCH 09/97] Add functionality for parsing single operations (Julia code) --- src/mlir/IR/Operation.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 6f45bbf8ec..5e0ac4d643 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -18,6 +18,17 @@ function Base.unsafe_convert(::Core.Type{API.MlirOperation}, operation::Operatio end Base.:(==)(op::Operation, other::Operation) = API.mlirOperationEqual(op, other) +""" + parse(::Type{Operation}, code; context=context()) + +Parses an operation from the string and transfers ownership to the caller. 
+""" +Base.parse(::Core.Type{Operation}, code; context::Context=context()) = Operation( + @ccall API.mlir_c.mlirOperationParse( + context::API.MlirContext, code::API.MlirStringRef + )::API.MlirOperation +) + """ copy(op) From 81dd8f2c289c222bd0f59fca5ac9c0c28c98b616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Mar 2025 10:15:23 -0600 Subject: [PATCH 10/97] Update `Ops.comm_rank` to use handwritten MLIR injection --- ext/ReactantMPIExt/Ops.jl | 69 ++++++++++++++------------------- ext/ReactantMPIExt/Overrides.jl | 2 +- 2 files changed, 31 insertions(+), 40 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 412f8b55ed..dc17fff6c8 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -6,6 +6,7 @@ using Reactant.MLIR.IR: @mlir_str using Reactant.MLIR.Dialects: mpi, func, llvm, enzymexla using Reactant.Ops: mlir_stacktrace using ..ReactantMPIExt: TracedRequest +using MPI: MPI # TODO add communicators @@ -17,57 +18,47 @@ using ..ReactantMPIExt: TracedRequest # return mpi.finalize(; location) # end -# TODO emit wrapper if not found -function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) +# TODO we might need to have a `TracedComm` for communicators created during the compiled function +function comm_rank(world; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) sym_name = "enzymexla_wrapper_MPI_Comm_rank" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # rettype = [IR.TensorType(Int[], IR.Type(Cint))] + tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) + signature = IR.Type[tensor_int_type, tensor_int_type] current_module = IR.mmodule() fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) if isnothing(fn) - # arg_type = IR.Type[IR.TensorType(Int[], IR.Type(Cint))] - arg_type = IR.Type[MLIR.IR.Type( - MLIR.API.mlirLLVMPointerTypeGet(IR.context(), Cuint(0)) - )] - function_type = IR.FunctionType(arg_type, IR.Type[]) - - @show arg_type function_type - - wrapper = IR.block!(IR.body(current_module)) do - func.func_(; sym_name, function_type, body=IR.Region()) - end - wrapper_body = IR.Block(arg_type, [IR.Location()]) - push!(IR.region(wrapper, 1), wrapper_body) - - # @show wrapper - - # fill the wrapper body - IR.block!(wrapper_body) do - # llvm.call( - # IR.Value[], - # IR.Value[]; - # callee=IR.FlatSymbolRefAttribute("MPI_Comm_rank"), - # op_bundle_sizes=MLIR.IR.Attribute(Cint[]), - # ) - # [IR.Type(Cint)], - # [IR.Type(Cint)], - # [IR.Type(Cint)], - # value = Reactant.Ops.constant(fill(Int32(1))) - # c = IR.result(llvm.mlir_constant(; res=IR.Type(Cint), value=1)) - # llvm.store(c, ...) - func.return_(IR.Value[]) + top_level_block = MLIR.IR.body(current_module) + #! format: off + code = parse(IR.Module, """ + module { + llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 + func.func @$sym_name(%world_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { + %world = llvm.load %world_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Comm_rank(%world, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + func.return + } + } + """) |> IR.body + #! 
format: on + + # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops + for op in collect(IR.OperationIterator(code)) + IR.rmfromparent!(op) + push!(top_level_block, op) end end - # world = Reactant.Ops.constant(fill(0)) - value_out = Reactant.Ops.constant(fill(0)) - # inputs = IR.Value[world.mlir_data] - inputs = IR.Value[value_out.mlir_data] + # NOTE we assume here that `MPI_Comm` is of word-size + world = Reactant.Ops.constant(Base.unsafe_convert(Cint, world)) + value_out = Reactant.Ops.constant(fill(Cint(-1))) + inputs = IR.Value[world.mlir_data, value_out.mlir_data] - res = IR.result(enzymexla.jit_call(inputs; fn=sym_attr, result_0=IR.Type[], location)) + res = IR.result( + enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2 + ) return TracedRNumber{Cint}((), res) end diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index c6845f60b0..a78c9f99e5 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -13,7 +13,7 @@ end @reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - return Ops.comm_rank() + return Ops.comm_rank(comm) end @reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm) From 2ebb52a5747b742d229117f26d5eca7dfbbbd589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Mar 2025 10:33:16 -0600 Subject: [PATCH 11/97] comment --- ext/ReactantMPIExt/Ops.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index dc17fff6c8..f83492e266 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -56,6 +56,7 @@ function comm_rank(world; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @ value_out = Reactant.Ops.constant(fill(Cint(-1))) inputs = IR.Value[world.mlir_data, value_out.mlir_data] + # TODO output_operand_aliases res = IR.result( enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2 ) From 11ef6457fe1cbcd5fa87d1ab915edf7c2fd44ad8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Mar 2025 10:49:08 -0600 Subject: [PATCH 12/97] Update `Ops.comm_size` --- ext/ReactantMPIExt/Ops.jl | 42 +++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index f83492e266..2389751c26 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -65,11 +65,45 @@ end # TODO emit wrapper if not found function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) - inputs = IR.Value[] - sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Comm_size") - rettype = [IR.TensorType(Int[], IR.Type(Cint))] + sym_name = "enzymexla_wrapper_MPI_Comm_size" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) + signature = IR.Type[tensor_int_type, tensor_int_type] - res = IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + current_module = IR.mmodule() + fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) + + if isnothing(fn) + top_level_block = MLIR.IR.body(current_module) + #! 
format: off + code = parse(IR.Module, """ + module { + llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32 + func.func @$sym_name(%world_ptr : !llvm.ptr, %size_ptr : !llvm.ptr) -> () { + %world = llvm.load %world_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Comm_size(%world, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + func.return + } + } + """) |> IR.body + #! format: on + + # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops + for op in collect(IR.OperationIterator(code)) + IR.rmfromparent!(op) + push!(top_level_block, op) + end + end + + world = Reactant.Ops.constant(Base.unsafe_convert(Cint, world)) + value_out = Reactant.Ops.constant(fill(Cint(-1))) + inputs = IR.Value[world.mlir_data, value_out.mlir_data] + + # TODO output_operand_aliases + res = IR.result( + enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2 + ) return TracedRNumber{Cint}((), res) end From 406fe0888122df419ea2443cd4b4f99bb835c896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Mar 2025 10:59:51 -0600 Subject: [PATCH 13/97] fixes --- ext/ReactantMPIExt/Ops.jl | 38 ++++++++++++++++----------------- ext/ReactantMPIExt/Overrides.jl | 2 +- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 2389751c26..58682114f0 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -18,14 +18,10 @@ using MPI: MPI # return mpi.finalize(; location) # end -# TODO we might need to have a `TracedComm` for communicators created during the compiled function -function comm_rank(world; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) +function comm_rank(comm; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) sym_name = "enzymexla_wrapper_MPI_Comm_rank" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) - signature = IR.Type[tensor_int_type, tensor_int_type] - current_module = IR.mmodule() fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) @@ -35,9 +31,9 @@ function comm_rank(world; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @ code = parse(IR.Module, """ module { llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 - func.func @$sym_name(%world_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { - %world = llvm.load %world_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Comm_rank(%world, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { + %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) func.return } } @@ -52,9 +48,12 @@ function comm_rank(world; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @ end # NOTE we assume here that `MPI_Comm` is of word-size - world = Reactant.Ops.constant(Base.unsafe_convert(Cint, world)) + comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) value_out = Reactant.Ops.constant(fill(Cint(-1))) - inputs = IR.Value[world.mlir_data, value_out.mlir_data] + inputs = IR.Value[comm.mlir_data, value_out.mlir_data] + + tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) + signature = IR.Type[tensor_int_type, tensor_int_type] # TODO output_operand_aliases res = IR.result( @@ -63,14 +62,10 @@ function comm_rank(world; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @ return TracedRNumber{Cint}((), res) end -# TODO emit wrapper if not found -function 
comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) +function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) sym_name = "enzymexla_wrapper_MPI_Comm_size" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) - signature = IR.Type[tensor_int_type, tensor_int_type] - current_module = IR.mmodule() fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) @@ -80,9 +75,9 @@ function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LIN code = parse(IR.Module, """ module { llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32 - func.func @$sym_name(%world_ptr : !llvm.ptr, %size_ptr : !llvm.ptr) -> () { - %world = llvm.load %world_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Comm_size(%world, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + func.func @$sym_name(%comm_ptr : !llvm.ptr, %size_ptr : !llvm.ptr) -> () { + %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Comm_size(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) func.return } } @@ -96,9 +91,12 @@ function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LIN end end - world = Reactant.Ops.constant(Base.unsafe_convert(Cint, world)) + comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) value_out = Reactant.Ops.constant(fill(Cint(-1))) - inputs = IR.Value[world.mlir_data, value_out.mlir_data] + inputs = IR.Value[comm.mlir_data, value_out.mlir_data] + + tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) + signature = IR.Type[tensor_int_type, tensor_int_type] # TODO output_operand_aliases res = IR.result( diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index a78c9f99e5..8d492c3fd8 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -18,7 +18,7 @@ end @reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - return Ops.comm_size() + return Ops.comm_size(comm) end @reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm) From 108c679bb28a279528d7392964e6ea445b7a0f2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Mar 2025 11:00:32 -0600 Subject: [PATCH 14/97] Refactor `Ops.barrier` --- ext/ReactantMPIExt/Ops.jl | 46 ++++++++++++++++++++++++++------- ext/ReactantMPIExt/Overrides.jl | 2 +- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 58682114f0..2f9a3fbdc5 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -8,7 +8,7 @@ using Reactant.Ops: mlir_stacktrace using ..ReactantMPIExt: TracedRequest using MPI: MPI -# TODO add communicators +# TODO we might need to have a `TracedComm` for communicators created during the compiled function # function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__)) # return mpi.init(; location) @@ -105,15 +105,43 @@ function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @_ return TracedRNumber{Cint}((), res) end -# TODO emit wrapper if not found -# TODO should we emit `stablehlo.optimization_barrier` here too? 
-function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__)) - inputs = IR.Value[] - sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Barrier") - rettype = IR.Type[] +function barrier(comm; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__)) + sym_name = "enzymexla_wrapper_MPI_Barrier" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # TODO should we return `TracedRNumber{Nothing}`? - return IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) + signature = IR.Type[tensor_int_type] + + current_module = IR.mmodule() + fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) + + if isnothing(fn) + top_level_block = MLIR.IR.body(current_module) + #! format: off + code = parse(IR.Module, """ + module { + llvm.func @MPI_Barrier(i32) -> i32 + func.func @$sym_name(%comm_ptr : !llvm.ptr) -> () { + %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Barrier(%comm) : (i32) -> (i32) + func.return + } + } + """) |> IR.body + #! format: on + + # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops + for op in collect(IR.OperationIterator(code)) + IR.rmfromparent!(op) + push!(top_level_block, op) + end + end + + comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) + inputs = [comm.mlir_data] + enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location) + + return nothing end # TODO emit wrapper if not found diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 8d492c3fd8..18d1f8c65b 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -23,7 +23,7 @@ end @reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - return Ops.barrier() + return Ops.barrier(comm) end # TODO status not supported yet From bfac727836483b4621064282a31bf4c2ea93f347 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Mar 2025 11:14:07 -0600 Subject: [PATCH 15/97] Refactor to `try_inject_to_top_block!` --- ext/ReactantMPIExt/Ops.jl | 122 +++++++++++++++++--------------------- 1 file changed, 53 insertions(+), 69 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 2f9a3fbdc5..1972c19a04 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -8,6 +8,26 @@ using Reactant.Ops: mlir_stacktrace using ..ReactantMPIExt: TracedRequest using MPI: MPI +function try_inject_to_top_block!(sym_name, code) + current_module = IR.mmodule() + fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) + + if isnothing(fn) + top_level_block = MLIR.IR.body(current_module) + code = IR.body(parse(IR.Module, code)) + + # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops + for op in collect(IR.OperationIterator(code)) + IR.rmfromparent!(op) + push!(top_level_block, op) + end + + return true + else + return false + end +end + # TODO we might need to have a `TracedComm` for communicators created during the compiled function # function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__)) @@ -22,30 +42,18 @@ function comm_rank(comm; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @_ sym_name = "enzymexla_wrapper_MPI_Comm_rank" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - current_module = IR.mmodule() - fn = 
IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) - - if isnothing(fn) - top_level_block = MLIR.IR.body(current_module) - #! format: off - code = parse(IR.Module, """ - module { - llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 - func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { - %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) - func.return - } + #! format: off + try_inject_to_top_block!(sym_name, """ + module { + llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 + func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { + %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + func.return } - """) |> IR.body - #! format: on - - # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops - for op in collect(IR.OperationIterator(code)) - IR.rmfromparent!(op) - push!(top_level_block, op) - end - end + } + """) + #! format: on # NOTE we assume here that `MPI_Comm` is of word-size comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) @@ -66,30 +74,18 @@ function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @_ sym_name = "enzymexla_wrapper_MPI_Comm_size" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - current_module = IR.mmodule() - fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) - - if isnothing(fn) - top_level_block = MLIR.IR.body(current_module) - #! format: off - code = parse(IR.Module, """ - module { - llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32 - func.func @$sym_name(%comm_ptr : !llvm.ptr, %size_ptr : !llvm.ptr) -> () { - %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Comm_size(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) - func.return - } + #! format: off + try_inject_to_top_block!(sym_name, """ + module { + llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32 + func.func @$sym_name(%comm_ptr : !llvm.ptr, %size_ptr : !llvm.ptr) -> () { + %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Comm_size(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + func.return } - """) |> IR.body - #! format: on - - # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops - for op in collect(IR.OperationIterator(code)) - IR.rmfromparent!(op) - push!(top_level_block, op) - end - end + } + """) + #! format: on comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) value_out = Reactant.Ops.constant(fill(Cint(-1))) @@ -112,30 +108,18 @@ function barrier(comm; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LIN tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) signature = IR.Type[tensor_int_type] - current_module = IR.mmodule() - fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) - - if isnothing(fn) - top_level_block = MLIR.IR.body(current_module) - #! format: off - code = parse(IR.Module, """ - module { - llvm.func @MPI_Barrier(i32) -> i32 - func.func @$sym_name(%comm_ptr : !llvm.ptr) -> () { - %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Barrier(%comm) : (i32) -> (i32) - func.return - } + #! 
format: off + try_inject_to_top_block!(sym_name, """ + module { + llvm.func @MPI_Barrier(i32) -> i32 + func.func @$sym_name(%comm_ptr : !llvm.ptr) -> () { + %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Barrier(%comm) : (i32) -> (i32) + func.return } - """) |> IR.body - #! format: on - - # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops - for op in collect(IR.OperationIterator(code)) - IR.rmfromparent!(op) - push!(top_level_block, op) - end - end + } + """) + #! format: on comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) inputs = [comm.mlir_data] From 20fadc2fd3aee374b876e12cb857fb7b008195ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 5 Mar 2025 17:53:52 +0100 Subject: [PATCH 16/97] Refactor MLIR injection --- ext/ReactantMPIExt/Ops.jl | 73 ++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 1972c19a04..c9e9753785 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -38,36 +38,69 @@ end # return mpi.finalize(; location) # end -function comm_rank(comm; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) +# TODO change to this kind of MLIR +# module { +# llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 +# func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { +# %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 +# %world_ptr = arith.constant dense<0x0asdfa> : tensor +# memref.get_global # global variable MPI_COMM_GLOBAL +# %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) +# func.return +# } +# func.func @real_$sym_name() -> tensor<> { +# %rank_ptr = stablehlo.constant dense<-1> : tensor # this is a placeholder +# %rank = enzymexla.jit_call @$sym_name(%world_ptr, %rank_ptr) { +# output_operand_alias = [ +# #stablehlo.output_operand_alias +# ] +# } +# } +# } + +function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) sym_name = "enzymexla_wrapper_MPI_Comm_rank" - sym_attr = IR.FlatSymbolRefAttribute(sym_name) + # sym_attr = IR.FlatSymbolRefAttribute(sym_name) + comm = MPI.COMM_WORLD #! format: off - try_inject_to_top_block!(sym_name, """ - module { - llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 - func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { - %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) - func.return + return Ops.hlo_call("""module { + llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 + func.func @$(sym_name)_jit(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { + %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + %comm = arith.constant $(Base.unsafe_convert(Cint, comm)) : i32 + %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + func.return + } + func.func @$sym_name() -> tensor { + %rank_placeholder = stablehlo.constant dense<-1> : tensor + %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) { + output_operand_alias = [ + #stablehlo.output_operand_alias + ] } + func.return %rank : tensor } - """) + }"""; func_name=sym_name) #! 
format: on # NOTE we assume here that `MPI_Comm` is of word-size - comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) - value_out = Reactant.Ops.constant(fill(Cint(-1))) - inputs = IR.Value[comm.mlir_data, value_out.mlir_data] + # comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) + # value_out = Reactant.Ops.constant(fill(Cint(-1))) + # inputs = IR.Value[comm.mlir_data, value_out.mlir_data] - tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) - signature = IR.Type[tensor_int_type, tensor_int_type] + # tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) + # signature = IR.Type[tensor_int_type, tensor_int_type] - # TODO output_operand_aliases - res = IR.result( - enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2 - ) - return TracedRNumber{Cint}((), res) + # # TODO output_operand_aliases + # res = IR.result( + # enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2 + # ) + # return TracedRNumber{Cint}((), res) end function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) From 2babb79e487fd1b3afa0584f85da336f2c107873 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 5 Mar 2025 17:54:20 +0100 Subject: [PATCH 17/97] Refactor MPI constante registration --- ext/ReactantMPIExt/ReactantMPIExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index bec689e9e3..8143aeaba9 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -213,7 +213,7 @@ function __init__() if value isa Base.RefValue value = value[] end - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Cint)::Cvoid + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Ptr{Cvoid})::Cvoid end end From 2c8e95c91aa92c6534517ac9f7a6876ba7f930ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 5 Mar 2025 11:43:44 -0600 Subject: [PATCH 18/97] Fix type inference in `Ops.hlo_call` on empty args --- src/Ops.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ops.jl b/src/Ops.jl index b49ca654c0..47ed47ee0a 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1510,7 +1510,7 @@ julia> Reactant.@jit( @assert expected_type == arg_type "hlo_call: argument #$i has the wrong type (expected $expected_type, got $arg_type)" end - operands = [a.mlir_data for a in args] + operands = MLIR.IR.Value[a.mlir_data for a in args] call = MLIR.Dialects.func.call( operands; result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)], From a6738f5cd814286cc907340cdc121e8668eb3628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 5 Mar 2025 11:44:31 -0600 Subject: [PATCH 19/97] Fix MLIR of `Ops.comm_rank` --- ext/ReactantMPIExt/Ops.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index c9e9753785..8aef7d4003 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -66,10 +66,9 @@ function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LIN comm = MPI.COMM_WORLD #! 
format: off - return Ops.hlo_call("""module { + return Reactant.Ops.hlo_call("""module { llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 - func.func @$(sym_name)_jit(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { - %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 + func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () { %comm = arith.constant $(Base.unsafe_convert(Cint, comm)) : i32 %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) func.return @@ -77,12 +76,12 @@ function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LIN func.func @$sym_name() -> tensor { %rank_placeholder = stablehlo.constant dense<-1> : tensor %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) { - output_operand_alias = [ + output_operand_aliases = [ #stablehlo.output_operand_alias ] - } + } : (tensor) -> (tensor) func.return %rank : tensor } }"""; func_name=sym_name) From 9710e59bbfb07de76fb64864b6cbd18e4dcf9981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 10:10:22 -0600 Subject: [PATCH 20/97] Fix MLIR injection C-functions --- deps/ReactantExtra/API.cpp | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index a27e9c16e8..d4aeabf7fa 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -227,13 +227,28 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc, unwrap(loc), cast(unwrap(type)), real, imag)); } -extern "C" MlirOperation mlirOperationParse(MlirContext ctx, - MlirStringRef code) { +extern "C" bool mlirOperationInject(MlirContext ctx, + MlirBlock block, + MlirStringRef code, + MlirLocation location) +{ ParserConfig config(unwrap(ctx)); - OwningOpRef owning_op = parseSourceString(unwrap(code), config); - if (!owning_op) + if (failed(parseSourceString(unwrap(code), unwrap(block), config))) + return false; + return true; +} + +extern "C" MlirOperation mlirOperationParseAppend(MlirContext ctx, + MlirBlock block, + MlirStringRef code, + MlirLocation location) { + ParserConfig config(unwrap(ctx)); + if (failed(parseSourceString(unwrap(code), unwrap(block), config))) return MlirOperation{nullptr}; - return MlirOperation{owning_op.release()}; + std::cout << "[ReactantExtra] YES?" 
<< std::endl; + return MlirOperation{ + mlir::detail::constructContainerOpForParserIfNecessary( + unwrap(block), config.getContext(), unwrap(location)).release()}; } // TODO mlirComplexAttrGetnValue From f865e3e25fe61d10f084f1c6d2a915b8a11f5462 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 10:11:20 -0600 Subject: [PATCH 21/97] Go back to `Cint` for registering symbols --- ext/ReactantMPIExt/ReactantMPIExt.jl | 4 ++-- src/mlir/IR/Operation.jl | 19 +++++++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 8143aeaba9..bd16e2e4a5 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -211,9 +211,9 @@ function __init__() ] value = getproperty(MPI.API, name) if value isa Base.RefValue - value = value[] + value = value[] # TODO we need to convert this to Ptr{Cvoid} because that's what the symbol table stores end - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Ptr{Cvoid})::Cvoid + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Cint)::Cvoid end end diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 5e0ac4d643..1bc8a0cde6 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -23,11 +23,22 @@ Base.:(==)(op::Operation, other::Operation) = API.mlirOperationEqual(op, other) Parses an operation from the string and transfers ownership to the caller. """ -Base.parse(::Core.Type{Operation}, code; context::Context=context()) = Operation( - @ccall API.mlir_c.mlirOperationParse( - context::API.MlirContext, code::API.MlirStringRef - )::API.MlirOperation +function Base.parse( + ::Core.Type{Operation}, + code; + context::Context=context(), + block=Block(), + location::Location=Location(), ) + return Operation( + @ccall API.mlir_c.mlirOperationParseAppend( + context::API.MlirContext, + block::API.MlirBlock, + code::API.MlirStringRef, + location::API.MlirLocation, + )::API.MlirOperation + ) +end """ copy(op) From 2778f1da9fc98a2c933cf357bcdc2abc99176ea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 10:11:44 -0600 Subject: [PATCH 22/97] Add `tryinjectop!` --- ext/ReactantMPIExt/Ops.jl | 20 -------------------- src/mlir/IR/IR.jl | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 8aef7d4003..e4b94685a3 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -8,26 +8,6 @@ using Reactant.Ops: mlir_stacktrace using ..ReactantMPIExt: TracedRequest using MPI: MPI -function try_inject_to_top_block!(sym_name, code) - current_module = IR.mmodule() - fn = IR.lookup(IR.SymbolTable(IR.Operation(current_module)), sym_name) - - if isnothing(fn) - top_level_block = MLIR.IR.body(current_module) - code = IR.body(parse(IR.Module, code)) - - # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops - for op in collect(IR.OperationIterator(code)) - IR.rmfromparent!(op) - push!(top_level_block, op) - end - - return true - else - return false - end -end - # TODO we might need to have a `TracedComm` for communicators created during the compiled function # function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__)) diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 8da48846fc..aba340ba04 100644 --- a/src/mlir/IR/IR.jl +++ 
b/src/mlir/IR/IR.jl @@ -134,4 +134,23 @@ function verifyall(operation::Operation; debug=false) end verifyall(module_::IR.Module; debug=false) = verifyall(Operation(module_); debug) +function tryinjectop!(sym_name, code; mod=IR.mmodule(), location=Location()) + fn = lookup(SymbolTable(Operation(mod)), sym_name) + + if isnothing(fn) + top_level_block = body(mod) + op = parse(Operation, code; block=top_level_block, location) + + # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops + # for op in collect(OperationIterator(code)) + # rmfromparent!(op) + # push!(top_level_block, op) + # end + + return op + else + return nothing + end +end + end # module IR From eb5da6e42757512450887440337789dc419d346e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 10:23:38 -0600 Subject: [PATCH 23/97] Add `tryinject!`, `inject!` methods --- src/mlir/IR/IR.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index aba340ba04..421dfc483a 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -134,6 +134,31 @@ function verifyall(operation::Operation; debug=false) end verifyall(module_::IR.Module; debug=false) = verifyall(Operation(module_); debug) +function tryinject!(sym_name, code; mod=IR.mmodule(), location=Location()) + fn = lookup(SymbolTable(Operation(mod)), sym_name) + + if isnothing(fn) + ctx = IR.context() + block = body(mod) + res = @ccall API.mlir_c.mlirOperationInject( + ctx::API.MlirContext, + block::API.MlirBlock, + code::API.MlirStringRef, + location::API.MlirLocation, + )::Bool + return res + else + return true + end +end + +function inject!(sym_name, code; kwargs...) + success = tryinject!(sym_name, code; kwargs...) + if !success + throw(ErrorException("Failed injecting MLIR to top-level block")) + end +end + function tryinjectop!(sym_name, code; mod=IR.mmodule(), location=Location()) fn = lookup(SymbolTable(Operation(mod)), sym_name) From 5995a7be9cd232bfdd5eefb7828e50de0b27a9b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 10:24:30 -0600 Subject: [PATCH 24/97] Update `comm_rank` --- ext/ReactantMPIExt/Ops.jl | 56 ++++++++++++++++++++++++--------- ext/ReactantMPIExt/Overrides.jl | 2 +- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index e4b94685a3..785e93bfa8 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -45,26 +45,52 @@ function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LIN # sym_attr = IR.FlatSymbolRefAttribute(sym_name) comm = MPI.COMM_WORLD + @show IR.mmodule() + + # memref.global constant @MPI_COMM_WORLD : memref + # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 + #! 
format: off - return Reactant.Ops.hlo_call("""module { - llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 + # IR.tryinjectop!("MPI_COMM_WORLD", "memref.global @MPI_COMM_WORLD : memref") + # IR.tryinjectop!("MPI_Comm_rank", "module { llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 }") + IR.inject!("$(sym_name)_jit", """ func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () { - %comm = arith.constant $(Base.unsafe_convert(Cint, comm)) : i32 + %comm_ref = memref.get_global @MPI_COMM_WORLD : memref + %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref) -> (!llvm.ptr) + %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32 %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) func.return } - func.func @$sym_name() -> tensor { - %rank_placeholder = stablehlo.constant dense<-1> : tensor - %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) { - output_operand_aliases = [ - #stablehlo.output_operand_alias - ] - } : (tensor) -> (tensor) - func.return %rank : tensor - } - }"""; func_name=sym_name) + """) + @show res + #! format: on + + # %comm_ref = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + # %comm = llvm.ptrtoint %comm_ref : !llvm.ptr to i32 + + #! format: off + # return Reactant.Ops.hlo_call("""module { + # memref.global constant @MPI_COMM_WORLD : memref + # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 + # func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () { + # %comm_ref = memref.get_global @MPI_COMM_WORLD : memref + # %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref) -> (!llvm.ptr) + # %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32 + # %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + # func.return + # } + # func.func @$sym_name() -> tensor { + # %rank_placeholder = stablehlo.constant dense<-1> : tensor + # %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) { + # output_operand_aliases = [ + # #stablehlo.output_operand_alias + # ] + # } : (tensor) -> (tensor) + # func.return %rank : tensor + # } + # }"""; func_name=sym_name) #! 
format: on # NOTE we assume here that `MPI_Comm` is of word-size diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 18d1f8c65b..81012c7f9a 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -13,7 +13,7 @@ end @reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - return Ops.comm_rank(comm) + return Ops.comm_rank() end @reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm) From 32b28ef35bd13c4cfb662a1810bcde62e07bc45a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 13:25:01 -0600 Subject: [PATCH 25/97] Update `mlirOperationInject`, `mlirOperationParse` --- deps/ReactantExtra/API.cpp | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index d4aeabf7fa..d455737b8e 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -230,25 +230,28 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc, extern "C" bool mlirOperationInject(MlirContext ctx, MlirBlock block, MlirStringRef code, - MlirLocation location) -{ - ParserConfig config(unwrap(ctx)); + MlirLocation location, + bool verify_after_parse +) { + ParserConfig config(unwrap(ctx), verify_after_parse); if (failed(parseSourceString(unwrap(code), unwrap(block), config))) return false; return true; } -extern "C" MlirOperation mlirOperationParseAppend(MlirContext ctx, - MlirBlock block, - MlirStringRef code, - MlirLocation location) { - ParserConfig config(unwrap(ctx)); +extern "C" MlirOperation mlirOperationParse(MlirContext ctx, + MlirBlock block, + MlirStringRef code, + MlirLocation location, + bool verify_after_parse +) { + ParserConfig config(unwrap(ctx), verify_after_parse); if (failed(parseSourceString(unwrap(code), unwrap(block), config))) return MlirOperation{nullptr}; - std::cout << "[ReactantExtra] YES?" 
<< std::endl; return MlirOperation{ mlir::detail::constructContainerOpForParserIfNecessary( - unwrap(block), config.getContext(), unwrap(location)).release()}; + unwrap(block), config.getContext(), unwrap(location)).release() + }; } // TODO mlirComplexAttrGetnValue From 6e9b1c53515e684b482103d77a71d06579342bc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 13:25:46 -0600 Subject: [PATCH 26/97] Add `verify` flag to `tryinject!`, `parse(::Operation)` --- src/mlir/IR/IR.jl | 3 ++- src/mlir/IR/Operation.jl | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 421dfc483a..9c7c4bbefa 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -134,7 +134,7 @@ function verifyall(operation::Operation; debug=false) end verifyall(module_::IR.Module; debug=false) = verifyall(Operation(module_); debug) -function tryinject!(sym_name, code; mod=IR.mmodule(), location=Location()) +function tryinject!(sym_name, code; verify=false, mod=IR.mmodule(), location=Location()) fn = lookup(SymbolTable(Operation(mod)), sym_name) if isnothing(fn) @@ -145,6 +145,7 @@ function tryinject!(sym_name, code; mod=IR.mmodule(), location=Location()) block::API.MlirBlock, code::API.MlirStringRef, location::API.MlirLocation, + verify::Bool, )::Bool return res else diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 1bc8a0cde6..4bf02580ff 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -26,16 +26,18 @@ Parses an operation from the string and transfers ownership to the caller. function Base.parse( ::Core.Type{Operation}, code; + verify::Bool = false, context::Context=context(), block=Block(), location::Location=Location(), ) return Operation( - @ccall API.mlir_c.mlirOperationParseAppend( + @ccall API.mlir_c.mlirOperationParse( context::API.MlirContext, block::API.MlirBlock, code::API.MlirStringRef, location::API.MlirLocation, + verify::Bool, )::API.MlirOperation ) end From f4acb153ac1134d4c009b5fc468a36dfa7e317bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 6 Mar 2025 13:26:01 -0600 Subject: [PATCH 27/97] Update `Ops.comm_rank` --- ext/ReactantMPIExt/Ops.jl | 102 ++++++++++---------------------------- 1 file changed, 26 insertions(+), 76 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 785e93bfa8..f6b3b3e295 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -18,94 +18,44 @@ using MPI: MPI # return mpi.finalize(; location) # end -# TODO change to this kind of MLIR -# module { -# llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 -# func.func @$sym_name(%comm_ptr : !llvm.ptr, %rank_ptr : !llvm.ptr) -> () { -# %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 -# %world_ptr = arith.constant dense<0x0asdfa> : tensor -# memref.get_global # global variable MPI_COMM_GLOBAL -# %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) -# func.return -# } -# func.func @real_$sym_name() -> tensor<> { -# %rank_ptr = stablehlo.constant dense<-1> : tensor # this is a placeholder -# %rank = enzymexla.jit_call @$sym_name(%world_ptr, %rank_ptr) { -# output_operand_alias = [ -# #stablehlo.output_operand_alias -# ] -# } -# } -# } - function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) sym_name = "enzymexla_wrapper_MPI_Comm_rank" - # sym_attr = IR.FlatSymbolRefAttribute(sym_name) - comm = MPI.COMM_WORLD - - @show IR.mmodule() + sym_attr = 
IR.FlatSymbolRefAttribute(sym_name) - # memref.global constant @MPI_COMM_WORLD : memref - # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 + # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast + # DONT LOAD FROM THEM! + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : i32") + IR.inject!("MPI_Comm_rank", "llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32") #! format: off - # IR.tryinjectop!("MPI_COMM_WORLD", "memref.global @MPI_COMM_WORLD : memref") - # IR.tryinjectop!("MPI_Comm_rank", "module { llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 }") - IR.inject!("$(sym_name)_jit", """ - func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () { - %comm_ref = memref.get_global @MPI_COMM_WORLD : memref - %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref) -> (!llvm.ptr) + IR.inject!(sym_name, """ + func.func @$sym_name(%rank_ptr : !llvm.ptr) -> () { + %comm_ptr = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32 %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) func.return } """) - @show res - #! format: on - - # %comm_ref = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr - # %comm = llvm.ptrtoint %comm_ref : !llvm.ptr to i32 - - #! format: off - # return Reactant.Ops.hlo_call("""module { - # memref.global constant @MPI_COMM_WORLD : memref - # llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32 - # func.func @$(sym_name)_jit(%rank_ptr : !llvm.ptr) -> () { - # %comm_ref = memref.get_global @MPI_COMM_WORLD : memref - # %comm_ptr = "enzymexla.memref2pointer"(%comm_ref) : (memref) -> (!llvm.ptr) - # %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32 - # %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) - # func.return - # } - # func.func @$sym_name() -> tensor { - # %rank_placeholder = stablehlo.constant dense<-1> : tensor - # %rank = enzymexla.jit_call @$(sym_name)_jit(%rank_placeholder) { - # output_operand_aliases = [ - # #stablehlo.output_operand_alias - # ] - # } : (tensor) -> (tensor) - # func.return %rank : tensor - # } - # }"""; func_name=sym_name) #! 
format: on + rank_placeholder = Reactant.Ops.constant(fill(Cint(-1))) + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) - # NOTE we assume here that `MPI_Comm` is of word-size - # comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) - # value_out = Reactant.Ops.constant(fill(Cint(-1))) - # inputs = IR.Value[comm.mlir_data, value_out.mlir_data] - - # tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) - # signature = IR.Type[tensor_int_type, tensor_int_type] - - # # TODO output_operand_aliases - # res = IR.result( - # enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2 - # ) - # return TracedRNumber{Cint}((), res) + res = IR.result( + enzymexla.jit_call( + IR.Value[rank_placeholder.mlir_data]; + fn=sym_attr, + result_0=[IR.TensorType(Int[], IR.Type(Cint))], + location, + output_operand_aliases, + ), + ) + return TracedRNumber{Cint}((), res) end function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) From 85a79ed12559059dff05cacd2be162ce6f7eb264 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 05:46:31 -0500 Subject: [PATCH 28/97] Update `comm_rank`, `comm_size`, `barrier`, `wait` --- ext/ReactantMPIExt/Ops.jl | 85 +++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index f6b3b3e295..6e6e6ff098 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -24,19 +24,19 @@ function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LIN # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast # DONT LOAD FROM THEM! - IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : i32") - IR.inject!("MPI_Comm_rank", "llvm.func @MPI_Comm_rank(i32, !llvm.ptr) -> i32") + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Comm_rank", "llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32") #! format: off IR.inject!(sym_name, """ func.func @$sym_name(%rank_ptr : !llvm.ptr) -> () { - %comm_ptr = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr - %comm = llvm.ptrtoint %comm_ptr : !llvm.ptr to i32 - %status = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %errcode = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (!llvm.ptr, !llvm.ptr) -> (i32) func.return } """) #! format: on + rank_placeholder = Reactant.Ops.constant(fill(Cint(-1))) output_operand_aliases = IR.Attribute([ IR.Attribute( @@ -58,19 +58,21 @@ function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LIN return TracedRNumber{Cint}((), res) end -function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) +function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) sym_name = "enzymexla_wrapper_MPI_Comm_size" sym_attr = IR.FlatSymbolRefAttribute(sym_name) + # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast + # DONT LOAD FROM THEM! + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Comm_size", "llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32") + #! 
format: off - try_inject_to_top_block!(sym_name, """ - module { - llvm.func @MPI_Comm_size(i32, !llvm.ptr) -> i32 - func.func @$sym_name(%comm_ptr : !llvm.ptr, %size_ptr : !llvm.ptr) -> () { - %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Comm_size(%comm, %rank_ptr) : (i32, !llvm.ptr) -> (i32) - func.return - } + IR.inject!(sym_name, """ + func.func @$sym_name(%size_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %errcode = llvm.call @MPI_Comm_rank(%comm, %size_ptr) : (!llvm.ptr, !llvm.ptr) -> (i32) + func.return } """) #! format: on @@ -89,26 +91,28 @@ function comm_size(comm; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @_ return TracedRNumber{Cint}((), res) end -function barrier(comm; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__)) +function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__)) sym_name = "enzymexla_wrapper_MPI_Barrier" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) - signature = IR.Type[tensor_int_type] + # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast + # DONT LOAD FROM THEM! + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Barrier", "llvm.func @MPI_Barrier(!llvm.ptr) -> i32") #! format: off - try_inject_to_top_block!(sym_name, """ - module { - llvm.func @MPI_Barrier(i32) -> i32 - func.func @$sym_name(%comm_ptr : !llvm.ptr) -> () { - %comm = llvm.load %comm_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Barrier(%comm) : (i32) -> (i32) - func.return - } + IR.inject!(sym_name, """ + func.func @$sym_name() -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %status = llvm.call @MPI_Barrier(%comm) : (!llvm.ptr) -> (i32) + func.return } """) #! format: on + tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) + signature = IR.Type[tensor_int_type] + comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) inputs = [comm.mlir_data] enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location) @@ -192,12 +196,33 @@ end function wait( req::TracedRequest; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__) ) - # return mpi.wait(req.mlir_data; location) - inputs = IR.Value[req.mlir_data] - sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Wait") - rettype = IR.Type[] + sym_name = "enzymexla_wrapper_MPI_Wait" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast + # DONT LOAD FROM THEM! + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") + + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%req : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %errcode = llvm.call @MPI_Wait(%req, %comm) : (!llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! 
format: on - return IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + enzymexla.jit_call( + IR.Value[req.mlir_data]; + fn=sym_attr, + result_0=IR.Type[], + location, + output_operand_aliases=IR.Attribute(IR.Attribute[]), + ) + + return nothing end end # module From 4e3477fd8c8c822a1e69200850dcec3c8f0e38d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 07:00:36 -0500 Subject: [PATCH 29/97] Implement `Ops.allreduce` --- ext/ReactantMPIExt/Ops.jl | 273 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 6e6e6ff098..d6175b9625 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -120,6 +120,185 @@ function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__) return nothing end +function inject_mpi_datatype(datatype) + if datatype == MPI.MPI_DATATYPE_NULL + IR.inject!( + "MPI_DATATYPE_NULL", + "llvm.mlir.global constant @MPI_DATATYPE_NULL() : !llvm.ptr", + ) + return "MPI_DATATYPE_NULL" + elseif datatype == MPI.MPI_BYTE + IR.inject!("MPI_BYTE", "llvm.mlir.global constant @MPI_BYTE() : !llvm.ptr") + return "MPI_BYTE" + elseif datatype == MPI.MPI_PACKED + IR.inject!("MPI_PACKED", "llvm.mlir.global constant @MPI_PACKED() : !llvm.ptr") + return "MPI_PACKED" + elseif datatype == MPI.MPI_CHAR + IR.inject!("MPI_CHAR", "llvm.mlir.global constant @MPI_CHAR() : !llvm.ptr") + return "MPI_CHAR" + elseif datatype == MPI.MPI_SHORT + IR.inject!("MPI_SHORT", "llvm.mlir.global constant @MPI_SHORT() : !llvm.ptr") + return "MPI_SHORT" + elseif datatype == MPI.MPI_INT + IR.inject!("MPI_INT", "llvm.mlir.global constant @MPI_INT() : !llvm.ptr") + return "MPI_INT" + elseif datatype == MPI.MPI_LONG + IR.inject!("MPI_LONG", "llvm.mlir.global constant @MPI_LONG() : !llvm.ptr") + return "MPI_LONG" + elseif datatype == MPI.MPI_FLOAT + IR.inject!("MPI_FLOAT", "llvm.mlir.global constant @MPI_FLOAT() : !llvm.ptr") + return "MPI_FLOAT" + elseif datatype == MPI.MPI_DOUBLE + IR.inject!("MPI_DOUBLE", "llvm.mlir.global constant @MPI_DOUBLE() : !llvm.ptr") + return "MPI_DOUBLE" + elseif datatype == MPI.MPI_UNSIGNED_CHAR + IR.inject!( + "MPI_UNSIGNED_CHAR", + "llvm.mlir.global constant @MPI_UNSIGNED_CHAR() : !llvm.ptr", + ) + return "MPI_UNSIGNED_CHAR" + elseif datatype == MPI.MPI_SIGNED_CHAR + IR.inject!( + "MPI_SIGNED_CHAR", "llvm.mlir.global constant @MPI_SIGNED_CHAR() : !llvm.ptr" + ) + return "MPI_SIGNED_CHAR" + elseif datatype == MPI.MPI_UNSIGNED_SHORT + IR.inject!( + "MPI_UNSIGNED_SHORT", + "llvm.mlir.global constant @MPI_UNSIGNED_SHORT() : !llvm.ptr", + ) + return "MPI_UNSIGNED_SHORT" + elseif datatype == MPI.MPI_UNSIGNED_LONG + IR.inject!( + "MPI_UNSIGNED_LONG", + "llvm.mlir.global constant @MPI_UNSIGNED_LONG() : !llvm.ptr", + ) + return "MPI_UNSIGNED_LONG" + elseif datatype == MPI.MPI_UNSIGNED + IR.inject!("MPI_UNSIGNED", "llvm.mlir.global constant @MPI_UNSIGNED() : !llvm.ptr") + return "MPI_UNSIGNED" + elseif datatype == MPI.MPI_FLOAT_INT + IR.inject!( + "MPI_FLOAT_INT", "llvm.mlir.global constant @MPI_FLOAT_INT() : !llvm.ptr" + ) + return "MPI_FLOAT_INT" + elseif datatype == MPI.MPI_DOUBLE_INT + IR.inject!( + "MPI_DOUBLE_INT", "llvm.mlir.global constant @MPI_DOUBLE_INT() : !llvm.ptr" + ) + return "MPI_DOUBLE_INT" + elseif datatype == MPI.MPI_LONG_DOUBLE_INT + IR.inject!( + "MPI_LONG_DOUBLE_INT", + "llvm.mlir.global constant @MPI_LONG_DOUBLE_INT() : !llvm.ptr", + ) + return "MPI_LONG_DOUBLE_INT" + elseif 
datatype == MPI.MPI_LONG_INT + IR.inject!("MPI_LONG_INT", "llvm.mlir.global constant @MPI_LONG_INT() : !llvm.ptr") + return "MPI_LONG_INT" + elseif datatype == MPI.MPI_SHORT_INT + IR.inject!( + "MPI_SHORT_INT", "llvm.mlir.global constant @MPI_SHORT_INT() : !llvm.ptr" + ) + return "MPI_SHORT_INT" + elseif datatype == MPI.MPI_UB + IR.inject!("MPI_UB", "llvm.mlir.global constant @MPI_UB() : !llvm.ptr") + return "MPI_UB" + elseif datatype == MPI.MPI_LB + IR.inject!("MPI_LB", "llvm.mlir.global constant @MPI_LB() : !llvm.ptr") + return "MPI_LB" + elseif datatype == MPI.MPI_WCHAR + IR.inject!("MPI_WCHAR", "llvm.mlir.global constant @MPI_WCHAR() : !llvm.ptr") + return "MPI_WCHAR" + elseif datatype == MPI.MPI_LONG_LONG_INT + IR.inject!( + "MPI_LONG_LONG_INT", + "llvm.mlir.global constant @MPI_LONG_LONG_INT() : !llvm.ptr", + ) + return "MPI_LONG_LONG_INT" + elseif datatype == MPI.MPI_UNSIGNED_LONG_LONG + IR.inject!( + "MPI_UNSIGNED_LONG_LONG", + "llvm.mlir.global constant @MPI_UNSIGNED_LONG_LONG() : !llvm.ptr", + ) + return "MPI_UNSIGNED_LONG_LONG" + elseif datatype == MPI.MPI_INT8_T + IR.inject!("MPI_INT8_T", "llvm.mlir.global constant @MPI_INT8_T() : !llvm.ptr") + return "MPI_INT8_T" + elseif datatype == MPI.MPI_UINT8_T + IR.inject!("MPI_UINT8_T", "llvm.mlir.global constant @MPI_UINT8_T() : !llvm.ptr") + return "MPI_UINT8_T" + elseif datatype == MPI.MPI_INT16_T + IR.inject!("MPI_INT16_T", "llvm.mlir.global constant @MPI_INT16_T() : !llvm.ptr") + return "MPI_INT16_T" + elseif datatype == MPI.MPI_UINT16_T + IR.inject!("MPI_UINT16_T", "llvm.mlir.global constant @MPI_UINT16_T() : !llvm.ptr") + return "MPI_UINT16_T" + elseif datatype == MPI.MPI_INT32_T + IR.inject!("MPI_INT32_T", "llvm.mlir.global constant @MPI_INT32_T() : !llvm.ptr") + return "MPI_INT32_T" + elseif datatype == MPI.MPI_UINT32_T + IR.inject!("MPI_UINT32_T", "llvm.mlir.global constant @MPI_UINT32_T() : !llvm.ptr") + return "MPI_UINT32_T" + elseif datatype == MPI.MPI_INT64_T + IR.inject!("MPI_INT64_T", "llvm.mlir.global constant @MPI_INT64_T() : !llvm.ptr") + return "MPI_INT64_T" + elseif datatype == MPI.MPI_UINT64_T + IR.inject!("MPI_UINT64_T", "llvm.mlir.global constant @MPI_UINT64_T() : !llvm.ptr") + return "MPI_UINT64_T" + elseif datatype == MPI.MPI_AINT + IR.inject!("MPI_AINT", "llvm.mlir.global constant @MPI_AINT() : !llvm.ptr") + return "MPI_AINT" + elseif datatype == MPI.MPI_OFFSET + IR.inject!("MPI_OFFSET", "llvm.mlir.global constant @MPI_OFFSET() : !llvm.ptr") + return "MPI_OFFSET" + elseif datatype == MPI.MPI_C_BOOL + IR.inject!("MPI_C_BOOL", "llvm.mlir.global constant @MPI_C_BOOL() : !llvm.ptr") + return "MPI_C_BOOL" + elseif datatype == MPI.MPI_C_FLOAT_COMPLEX + IR.inject!( + "MPI_C_FLOAT_COMPLEX", + "llvm.mlir.global constant @MPI_C_FLOAT_COMPLEX() : !llvm.ptr", + ) + return "MPI_C_FLOAT_COMPLEX" + elseif datatype == MPI.MPI_C_DOUBLE_COMPLEX + IR.inject!( + "MPI_C_DOUBLE_COMPLEX", + "llvm.mlir.global constant @MPI_C_DOUBLE_COMPLEX() : !llvm.ptr", + ) + return "MPI_C_DOUBLE_COMPLEX" + elseif datatype == MPI.MPI_COUNT + IR.inject!("MPI_COUNT", "llvm.mlir.global constant @MPI_COUNT() : !llvm.ptr") + return "MPI_COUNT" + else + throw(ArgumentError("Unknown MPI datatype `$datatype`")) + end +end + +function convert_julia_type_to_mpi_datatype(T::Type) + if T === Bool + MPI.C_BOOL + elseif T === Int8 + MPI.INT8_T + elseif T === Int16 + MPI.INT16_T + elseif T === Int32 + MPI.INT32_T + elseif T === Int64 + MPI.INT64_T + elseif T === Float32 + MPI.FLOAT + elseif T === Float64 + MPI.DOUBLE + elseif T === ComplexF32 + 
MPI.C_FLOAT_COMPLEX + elseif T === ComplexF64 + MPI.C_DOUBLE_COMPLEX + else + throw(ArgumentError("Unknown conversion from $T to a MPI_Datatype")) + end +end + # TODO emit wrapper if not found function send( buf::TracedRArray, @@ -225,4 +404,98 @@ function wait( return nothing end +function inject_mpi_op!(op) + if op == MPI.OP_NULL + IR.inject!("MPI_OP_NULL", "llvm.mlir.global constant @MPI_OP_NULL() : !llvm.ptr") + return "MPI_OP_NULL" + elseif op == MPI.MAX + IR.inject!("MPI_MAX", "llvm.mlir.global constant @MPI_MAX() : !llvm.ptr") + return "MPI_MAX" + elseif op == MPI.MIN + IR.inject!("MPI_MIN", "llvm.mlir.global constant @MPI_MIN() : !llvm.ptr") + return "MPI_MIN" + elseif op == MPI.SUM + IR.inject!("MPI_SUM", "llvm.mlir.global constant @MPI_SUM() : !llvm.ptr") + return "MPI_SUM" + elseif op == MPI.PROD + IR.inject!("MPI_PROD", "llvm.mlir.global constant @MPI_PROD() : !llvm.ptr") + return "MPI_PROD" + elseif op == MPI.LAND + IR.inject!("MPI_LAND", "llvm.mlir.global constant @MPI_LAND() : !llvm.ptr") + return "MPI_LAND" + elseif op == MPI.BAND + IR.inject!("MPI_BAND", "llvm.mlir.global constant @MPI_BAND() : !llvm.ptr") + return "MPI_BAND" + elseif op == MPI.LOR + IR.inject!("MPI_LOR", "llvm.mlir.global constant @MPI_LOR() : !llvm.ptr") + return "MPI_LOR" + elseif op == MPI.BOR + IR.inject!("MPI_BOR", "llvm.mlir.global constant @MPI_BOR() : !llvm.ptr") + return "MPI_BOR" + elseif op == MPI.LXOR + IR.inject!("MPI_LXOR", "llvm.mlir.global constant @MPI_LXOR() : !llvm.ptr") + return "MPI_LXOR" + elseif op == MPI.BXOR + IR.inject!("MPI_BXOR", "llvm.mlir.global constant @MPI_BXOR() : !llvm.ptr") + return "MPI_BXOR" + elseif op == MPI.REPLACE + IR.inject!("MPI_REPLACE", "llvm.mlir.global constant @MPI_REPLACE() : !llvm.ptr") + return "MPI_REPLACE" + elseif op == MPI.NO_OP + IR.inject!("MPI_NO_OP", "llvm.mlir.global constant @MPI_NO_OP() : !llvm.ptr") + return "MPI_NO_OP" + else + throw(ArgumentError("Unknown MPI operation `$op`")) + end +end + +function allreduce( + op, sendbuf, recvbuf; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__) +) + @assert Reactant.unwrapped_eltype(sendbuf) == Reactant.unwrapped_eltype(recvbuf) + @assert length(sendbuf) == recvbuf + + op_name = inject_mpi_op(op) + T = Reactant.unwrapped_eltype(sendbuf) + mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype_name = inject_mpi_datatype(mpi_datatype) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Allreduce", + "llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32", + ) + + sym_name = "@enzymexla_wrapper_MPI_Allreduce_$(op_name)_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + # TODO is okay to use `i32`? how can we use word-size value or map C's `int` to MLIR? can we use `index`? + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%sendbuf : !llvm.ptr, %recvbuf : !llvm.ptr, %count_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %op = llvm.mlir.addressof @$op_name : !llvm.ptr + %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %errcode = llvm.call @MPI_Allreduce(%sendbuf, %recvbuf, %count, %datatype, %op, %comm) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! 
format: on + + count = Ops.constant(fill(length(sendbuf))) + + res = IR.result( + enzymexla.jit_call( + IR.Value[sendbuf.mlir_data, recvbuf.mlir_data, count.mlir_data]; + fn=sym_attr, + result_0=IR.Type[], + location, + output_operand_aliases=IR.Attribute(IR.Attribute[]), + ), + ) + + return TracedRNumber{T}((), res) +end + end # module From 9fb01635de8abcbc32851973c6d0c2a39d076d16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 07:09:59 -0500 Subject: [PATCH 30/97] Implement `Ops.send` --- ext/ReactantMPIExt/Ops.jl | 41 +++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index d6175b9625..909e392b8f 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -299,21 +299,50 @@ function convert_julia_type_to_mpi_datatype(T::Type) end end -# TODO emit wrapper if not found function send( buf::TracedRArray, tag::TracedRNumber, dest::TracedRNumber; location=mlir_stacktrace("mpi.send", @__FILE__, @__LINE__), ) - # return mpi.send(buf.mlir_data, tag.mlir_data, dest.mlir_data; location) + T = Reactant.unwrapped_eltype(buf) + mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype_name = inject_mpi_datatype(mpi_datatype) + + sym_name = "enzymexla_wrapper_MPI_Send_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) # TODO emit constant for size and datatype, and pass as args - inputs = IR.Value[buf.mlir_data, tag.mlir_data, dest.mlir_data] - sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Send") - rettype = IR.Type[] - return enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location) + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Send", + "llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32", + ) + + #! format: off + # TODO + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %dest = llvm.load %dest_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + %status = llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + count = Reactant.Ops.constant(length(buf)) + + return enzymexla.jit_call( + IR.Value[buf.mlir_data, count.mlir_data, tag.mlir_data, dest.mlir_data]; + fn=sym_attr, + result_0=IR.Type[], + location, + ) end # TODO need c-function for creating MLIR `mpi.request` type? From f14dd807923d7556a44ca562cd40fb3c16a7296e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 07:10:23 -0500 Subject: [PATCH 31/97] Remove comment --- ext/ReactantMPIExt/Ops.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 909e392b8f..07f4d7f215 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -22,8 +22,6 @@ function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LIN sym_name = "enzymexla_wrapper_MPI_Comm_rank" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast - # DONT LOAD FROM THEM! 
IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!("MPI_Comm_rank", "llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32") @@ -62,8 +60,6 @@ function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LIN sym_name = "enzymexla_wrapper_MPI_Comm_size" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast - # DONT LOAD FROM THEM! IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!("MPI_Comm_size", "llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32") @@ -95,8 +91,6 @@ function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__) sym_name = "enzymexla_wrapper_MPI_Barrier" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast - # DONT LOAD FROM THEM! IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!("MPI_Barrier", "llvm.func @MPI_Barrier(!llvm.ptr) -> i32") @@ -407,8 +401,6 @@ function wait( sym_name = "enzymexla_wrapper_MPI_Wait" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # dirty hack: since MPI constants are i32, we pass the info as the pointer and then bitcast - # DONT LOAD FROM THEM! IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") From 14d84e2957123806beabfd4449e8c402ed66581e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 07:10:52 -0500 Subject: [PATCH 32/97] Update `Ops.wait` --- ext/ReactantMPIExt/Ops.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 07f4d7f215..1eb4b186e2 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -406,10 +406,10 @@ function wait( #! format: off IR.inject!(sym_name, """ - func.func @$sym_name(%req : !llvm.ptr) -> () { - %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr - %errcode = llvm.call @MPI_Wait(%req, %comm) : (!llvm.ptr, !llvm.ptr) -> (i32) - func.return + func.func @$sym_name(%req : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %errcode = llvm.call @MPI_Wait(%req, %comm) : (!llvm.ptr, !llvm.ptr) -> (i32) + func.return } """) #! 
format: on From beebfe8a5ea46f0379002288ee5877493829a54c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 07:12:28 -0500 Subject: [PATCH 33/97] Remove `comm` argument from `Comm_size`, `Barrier` overrides --- ext/ReactantMPIExt/Overrides.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 81012c7f9a..c6845f60b0 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -18,12 +18,12 @@ end @reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - return Ops.comm_size(comm) + return Ops.comm_size() end @reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - return Ops.barrier(comm) + return Ops.barrier() end # TODO status not supported yet From 708a57489ac4505007d134ee3f12977fd5d50693 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 07:20:05 -0500 Subject: [PATCH 34/97] Fix `Ops.comm_size` --- ext/ReactantMPIExt/Ops.jl | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 1eb4b186e2..cbe2c9464e 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -67,22 +67,29 @@ function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LIN IR.inject!(sym_name, """ func.func @$sym_name(%size_ptr : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr - %errcode = llvm.call @MPI_Comm_rank(%comm, %size_ptr) : (!llvm.ptr, !llvm.ptr) -> (i32) + %errcode = llvm.call @MPI_Comm_size(%comm, %size_ptr) : (!llvm.ptr, !llvm.ptr) -> (i32) func.return } """) #! format: on - comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) - value_out = Reactant.Ops.constant(fill(Cint(-1))) - inputs = IR.Value[comm.mlir_data, value_out.mlir_data] - - tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) - signature = IR.Type[tensor_int_type, tensor_int_type] + size_placeholder = Reactant.Ops.constant(fill(Cint(-1))) + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) - # TODO output_operand_aliases res = IR.result( - enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location), 2 + enzymexla.jit_call( + IR.Value[size_placeholder.mlir_data]; + fn=sym_attr, + result_0=[IR.TensorType(Int[], IR.Type(Cint))], + output_operand_aliases, + location, + ), ) return TracedRNumber{Cint}((), res) end From da59f4a4c41185aff4e75ca2c7ea076ef42d4a14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 07:22:45 -0500 Subject: [PATCH 35/97] Fix `Ops.barrier` --- ext/ReactantMPIExt/Ops.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index cbe2c9464e..acb3a40548 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -111,12 +111,10 @@ function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__) """) #! 
format: on - tensor_int_type = IR.TensorType(Int[], IR.Type(Cint)) - signature = IR.Type[tensor_int_type] - - comm = Reactant.Ops.constant(Base.unsafe_convert(Cint, comm)) - inputs = [comm.mlir_data] - enzymexla.jit_call(inputs; fn=sym_attr, result_0=signature, location) + output_operand_aliases = IR.Attribute(IR.Attribute[]) + enzymexla.jit_call( + IR.Value[]; fn=sym_attr, result_0=IR.Type[], output_operand_aliases, location + ) return nothing end From b6a9cdfd82dbb8c22ff8413720a0f7e3a420ab4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 08:06:39 -0500 Subject: [PATCH 36/97] Fixes and renames --- ext/ReactantMPIExt/Ops.jl | 112 +++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 51 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index acb3a40548..491ee1a080 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -119,154 +119,154 @@ function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__) return nothing end -function inject_mpi_datatype(datatype) - if datatype == MPI.MPI_DATATYPE_NULL +function inject_mpi_datatype!(datatype) + if datatype == MPI.DATATYPE_NULL IR.inject!( "MPI_DATATYPE_NULL", "llvm.mlir.global constant @MPI_DATATYPE_NULL() : !llvm.ptr", ) return "MPI_DATATYPE_NULL" - elseif datatype == MPI.MPI_BYTE + elseif datatype == MPI.BYTE IR.inject!("MPI_BYTE", "llvm.mlir.global constant @MPI_BYTE() : !llvm.ptr") return "MPI_BYTE" - elseif datatype == MPI.MPI_PACKED - IR.inject!("MPI_PACKED", "llvm.mlir.global constant @MPI_PACKED() : !llvm.ptr") - return "MPI_PACKED" - elseif datatype == MPI.MPI_CHAR + # elseif datatype == MPI.PACKED + # IR.inject!("MPI_PACKED", "llvm.mlir.global constant @MPI_PACKED() : !llvm.ptr") + # return "MPI_PACKED" + elseif datatype == MPI.CHAR IR.inject!("MPI_CHAR", "llvm.mlir.global constant @MPI_CHAR() : !llvm.ptr") return "MPI_CHAR" - elseif datatype == MPI.MPI_SHORT + elseif datatype == MPI.SHORT IR.inject!("MPI_SHORT", "llvm.mlir.global constant @MPI_SHORT() : !llvm.ptr") return "MPI_SHORT" - elseif datatype == MPI.MPI_INT + elseif datatype == MPI.INT IR.inject!("MPI_INT", "llvm.mlir.global constant @MPI_INT() : !llvm.ptr") return "MPI_INT" - elseif datatype == MPI.MPI_LONG + elseif datatype == MPI.LONG IR.inject!("MPI_LONG", "llvm.mlir.global constant @MPI_LONG() : !llvm.ptr") return "MPI_LONG" - elseif datatype == MPI.MPI_FLOAT + elseif datatype == MPI.FLOAT IR.inject!("MPI_FLOAT", "llvm.mlir.global constant @MPI_FLOAT() : !llvm.ptr") return "MPI_FLOAT" - elseif datatype == MPI.MPI_DOUBLE + elseif datatype == MPI.DOUBLE IR.inject!("MPI_DOUBLE", "llvm.mlir.global constant @MPI_DOUBLE() : !llvm.ptr") return "MPI_DOUBLE" - elseif datatype == MPI.MPI_UNSIGNED_CHAR + elseif datatype == MPI.UNSIGNED_CHAR IR.inject!( "MPI_UNSIGNED_CHAR", "llvm.mlir.global constant @MPI_UNSIGNED_CHAR() : !llvm.ptr", ) return "MPI_UNSIGNED_CHAR" - elseif datatype == MPI.MPI_SIGNED_CHAR + elseif datatype == MPI.SIGNED_CHAR IR.inject!( "MPI_SIGNED_CHAR", "llvm.mlir.global constant @MPI_SIGNED_CHAR() : !llvm.ptr" ) return "MPI_SIGNED_CHAR" - elseif datatype == MPI.MPI_UNSIGNED_SHORT + elseif datatype == MPI.UNSIGNED_SHORT IR.inject!( "MPI_UNSIGNED_SHORT", "llvm.mlir.global constant @MPI_UNSIGNED_SHORT() : !llvm.ptr", ) return "MPI_UNSIGNED_SHORT" - elseif datatype == MPI.MPI_UNSIGNED_LONG + elseif datatype == MPI.UNSIGNED_LONG IR.inject!( "MPI_UNSIGNED_LONG", "llvm.mlir.global constant @MPI_UNSIGNED_LONG() : !llvm.ptr", ) return 
"MPI_UNSIGNED_LONG" - elseif datatype == MPI.MPI_UNSIGNED + elseif datatype == MPI.UNSIGNED IR.inject!("MPI_UNSIGNED", "llvm.mlir.global constant @MPI_UNSIGNED() : !llvm.ptr") return "MPI_UNSIGNED" - elseif datatype == MPI.MPI_FLOAT_INT + elseif datatype == MPI.FLOAT_INT IR.inject!( "MPI_FLOAT_INT", "llvm.mlir.global constant @MPI_FLOAT_INT() : !llvm.ptr" ) return "MPI_FLOAT_INT" - elseif datatype == MPI.MPI_DOUBLE_INT + elseif datatype == MPI.DOUBLE_INT IR.inject!( "MPI_DOUBLE_INT", "llvm.mlir.global constant @MPI_DOUBLE_INT() : !llvm.ptr" ) return "MPI_DOUBLE_INT" - elseif datatype == MPI.MPI_LONG_DOUBLE_INT + elseif datatype == MPI.LONG_DOUBLE_INT IR.inject!( "MPI_LONG_DOUBLE_INT", "llvm.mlir.global constant @MPI_LONG_DOUBLE_INT() : !llvm.ptr", ) return "MPI_LONG_DOUBLE_INT" - elseif datatype == MPI.MPI_LONG_INT + elseif datatype == MPI.LONG_INT IR.inject!("MPI_LONG_INT", "llvm.mlir.global constant @MPI_LONG_INT() : !llvm.ptr") return "MPI_LONG_INT" - elseif datatype == MPI.MPI_SHORT_INT + elseif datatype == MPI.SHORT_INT IR.inject!( "MPI_SHORT_INT", "llvm.mlir.global constant @MPI_SHORT_INT() : !llvm.ptr" ) return "MPI_SHORT_INT" - elseif datatype == MPI.MPI_UB + elseif datatype == MPI.UB IR.inject!("MPI_UB", "llvm.mlir.global constant @MPI_UB() : !llvm.ptr") return "MPI_UB" - elseif datatype == MPI.MPI_LB + elseif datatype == MPI.LB IR.inject!("MPI_LB", "llvm.mlir.global constant @MPI_LB() : !llvm.ptr") return "MPI_LB" - elseif datatype == MPI.MPI_WCHAR + elseif datatype == MPI.WCHAR IR.inject!("MPI_WCHAR", "llvm.mlir.global constant @MPI_WCHAR() : !llvm.ptr") return "MPI_WCHAR" - elseif datatype == MPI.MPI_LONG_LONG_INT + elseif datatype == MPI.LONG_LONG_INT IR.inject!( "MPI_LONG_LONG_INT", "llvm.mlir.global constant @MPI_LONG_LONG_INT() : !llvm.ptr", ) return "MPI_LONG_LONG_INT" - elseif datatype == MPI.MPI_UNSIGNED_LONG_LONG + elseif datatype == MPI.UNSIGNED_LONG_LONG IR.inject!( "MPI_UNSIGNED_LONG_LONG", "llvm.mlir.global constant @MPI_UNSIGNED_LONG_LONG() : !llvm.ptr", ) return "MPI_UNSIGNED_LONG_LONG" - elseif datatype == MPI.MPI_INT8_T + elseif datatype == MPI.INT8_T IR.inject!("MPI_INT8_T", "llvm.mlir.global constant @MPI_INT8_T() : !llvm.ptr") return "MPI_INT8_T" - elseif datatype == MPI.MPI_UINT8_T + elseif datatype == MPI.UINT8_T IR.inject!("MPI_UINT8_T", "llvm.mlir.global constant @MPI_UINT8_T() : !llvm.ptr") return "MPI_UINT8_T" - elseif datatype == MPI.MPI_INT16_T + elseif datatype == MPI.INT16_T IR.inject!("MPI_INT16_T", "llvm.mlir.global constant @MPI_INT16_T() : !llvm.ptr") return "MPI_INT16_T" - elseif datatype == MPI.MPI_UINT16_T + elseif datatype == MPI.UINT16_T IR.inject!("MPI_UINT16_T", "llvm.mlir.global constant @MPI_UINT16_T() : !llvm.ptr") return "MPI_UINT16_T" - elseif datatype == MPI.MPI_INT32_T + elseif datatype == MPI.INT32_T IR.inject!("MPI_INT32_T", "llvm.mlir.global constant @MPI_INT32_T() : !llvm.ptr") return "MPI_INT32_T" - elseif datatype == MPI.MPI_UINT32_T + elseif datatype == MPI.UINT32_T IR.inject!("MPI_UINT32_T", "llvm.mlir.global constant @MPI_UINT32_T() : !llvm.ptr") return "MPI_UINT32_T" - elseif datatype == MPI.MPI_INT64_T + elseif datatype == MPI.INT64_T IR.inject!("MPI_INT64_T", "llvm.mlir.global constant @MPI_INT64_T() : !llvm.ptr") return "MPI_INT64_T" - elseif datatype == MPI.MPI_UINT64_T + elseif datatype == MPI.UINT64_T IR.inject!("MPI_UINT64_T", "llvm.mlir.global constant @MPI_UINT64_T() : !llvm.ptr") return "MPI_UINT64_T" - elseif datatype == MPI.MPI_AINT + elseif datatype == MPI.AINT IR.inject!("MPI_AINT", "llvm.mlir.global 
constant @MPI_AINT() : !llvm.ptr") return "MPI_AINT" - elseif datatype == MPI.MPI_OFFSET + elseif datatype == MPI.OFFSET IR.inject!("MPI_OFFSET", "llvm.mlir.global constant @MPI_OFFSET() : !llvm.ptr") return "MPI_OFFSET" - elseif datatype == MPI.MPI_C_BOOL + elseif datatype == MPI.C_BOOL IR.inject!("MPI_C_BOOL", "llvm.mlir.global constant @MPI_C_BOOL() : !llvm.ptr") return "MPI_C_BOOL" - elseif datatype == MPI.MPI_C_FLOAT_COMPLEX + elseif datatype == MPI.C_FLOAT_COMPLEX IR.inject!( "MPI_C_FLOAT_COMPLEX", "llvm.mlir.global constant @MPI_C_FLOAT_COMPLEX() : !llvm.ptr", ) return "MPI_C_FLOAT_COMPLEX" - elseif datatype == MPI.MPI_C_DOUBLE_COMPLEX + elseif datatype == MPI.C_DOUBLE_COMPLEX IR.inject!( "MPI_C_DOUBLE_COMPLEX", "llvm.mlir.global constant @MPI_C_DOUBLE_COMPLEX() : !llvm.ptr", ) return "MPI_C_DOUBLE_COMPLEX" - elseif datatype == MPI.MPI_COUNT + elseif datatype == MPI.COUNT IR.inject!("MPI_COUNT", "llvm.mlir.global constant @MPI_COUNT() : !llvm.ptr") return "MPI_COUNT" else @@ -306,7 +306,7 @@ function send( ) T = Reactant.unwrapped_eltype(buf) mpi_datatype = convert_julia_type_to_mpi_datatype(T) - mpi_datatype_name = inject_mpi_datatype(mpi_datatype) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) sym_name = "enzymexla_wrapper_MPI_Send_$(mpi_datatype_name)" sym_attr = IR.FlatSymbolRefAttribute(sym_name) @@ -475,16 +475,16 @@ function inject_mpi_op!(op) end end -function allreduce( +function allreduce!( op, sendbuf, recvbuf; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__) ) @assert Reactant.unwrapped_eltype(sendbuf) == Reactant.unwrapped_eltype(recvbuf) - @assert length(sendbuf) == recvbuf + @assert length(sendbuf) == length(recvbuf) - op_name = inject_mpi_op(op) + op_name = inject_mpi_op!(op) T = Reactant.unwrapped_eltype(sendbuf) mpi_datatype = convert_julia_type_to_mpi_datatype(T) - mpi_datatype_name = inject_mpi_datatype(mpi_datatype) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!( @@ -492,7 +492,7 @@ function allreduce( "llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32", ) - sym_name = "@enzymexla_wrapper_MPI_Allreduce_$(op_name)_$(mpi_datatype_name)" + sym_name = "enzymexla_wrapper_MPI_Allreduce_$(op_name)_$(mpi_datatype_name)" sym_attr = IR.FlatSymbolRefAttribute(sym_name) # TODO is okay to use `i32`? how can we use word-size value or map C's `int` to MLIR? can we use `index`? @@ -509,19 +509,29 @@ function allreduce( """) #! 
format: on - count = Ops.constant(fill(length(sendbuf))) + count = Reactant.Ops.constant(fill(length(sendbuf))) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 1, 0, C_NULL + ), + ), + ]) res = IR.result( enzymexla.jit_call( IR.Value[sendbuf.mlir_data, recvbuf.mlir_data, count.mlir_data]; fn=sym_attr, - result_0=IR.Type[], + result_0=IR.Type[Reactant.Ops.mlir_type(typeof(recvbuf), size(recvbuf))], location, - output_operand_aliases=IR.Attribute(IR.Attribute[]), + output_operand_aliases, ), ) - return TracedRNumber{T}((), res) + recvbuf.mlir_data = res + + return recvbuf end end # module From c207e02bda59feb3e43881ea14b6a1c0c3b4fdd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 08:06:50 -0500 Subject: [PATCH 37/97] Override `MPI.Allreduce!` --- ext/ReactantMPIExt/Overrides.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index c6845f60b0..7b7f59e122 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -115,3 +115,8 @@ function MPI.Irecv!( req.mlir_data = gen_req.mlir_data return req end + +function MPI.Allreduce!(sendbuf::TracedRArray, recvbuf::TracedRArray, op, comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.allreduce!(op, sendbuf, recvbuf) +end From 287bd327e8f6b031e69eb99a43c487d64359077d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 09:25:30 -0500 Subject: [PATCH 38/97] Fix conversion of MPI constants to word-size type --- ext/ReactantMPIExt/ReactantMPIExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index bd16e2e4a5..49e00e4a07 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -213,7 +213,8 @@ function __init__() if value isa Base.RefValue value = value[] # TODO we need to convert this to Ptr{Cvoid} because that's what the symbol table stores end - @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Cint)::Cvoid + value = convert(Int, value) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Int)::Cvoid end end From 0c32d6549b1481df0f0ac3267e5bfe5486dee2e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 12:28:00 -0500 Subject: [PATCH 39/97] Comment unused MPI datatypes --- ext/ReactantMPIExt/Ops.jl | 60 +++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 491ee1a080..e29178b7cd 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -176,36 +176,36 @@ function inject_mpi_datatype!(datatype) elseif datatype == MPI.UNSIGNED IR.inject!("MPI_UNSIGNED", "llvm.mlir.global constant @MPI_UNSIGNED() : !llvm.ptr") return "MPI_UNSIGNED" - elseif datatype == MPI.FLOAT_INT - IR.inject!( - "MPI_FLOAT_INT", "llvm.mlir.global constant @MPI_FLOAT_INT() : !llvm.ptr" - ) - return "MPI_FLOAT_INT" - elseif datatype == MPI.DOUBLE_INT - IR.inject!( - "MPI_DOUBLE_INT", "llvm.mlir.global constant @MPI_DOUBLE_INT() : !llvm.ptr" - ) - return "MPI_DOUBLE_INT" - elseif datatype == MPI.LONG_DOUBLE_INT - IR.inject!( - "MPI_LONG_DOUBLE_INT", - "llvm.mlir.global constant @MPI_LONG_DOUBLE_INT() : 
!llvm.ptr", - ) - return "MPI_LONG_DOUBLE_INT" - elseif datatype == MPI.LONG_INT - IR.inject!("MPI_LONG_INT", "llvm.mlir.global constant @MPI_LONG_INT() : !llvm.ptr") - return "MPI_LONG_INT" - elseif datatype == MPI.SHORT_INT - IR.inject!( - "MPI_SHORT_INT", "llvm.mlir.global constant @MPI_SHORT_INT() : !llvm.ptr" - ) - return "MPI_SHORT_INT" - elseif datatype == MPI.UB - IR.inject!("MPI_UB", "llvm.mlir.global constant @MPI_UB() : !llvm.ptr") - return "MPI_UB" - elseif datatype == MPI.LB - IR.inject!("MPI_LB", "llvm.mlir.global constant @MPI_LB() : !llvm.ptr") - return "MPI_LB" + # elseif datatype == MPI.FLOAT_INT + # IR.inject!( + # "MPI_FLOAT_INT", "llvm.mlir.global constant @MPI_FLOAT_INT() : !llvm.ptr" + # ) + # return "MPI_FLOAT_INT" + # elseif datatype == MPI.DOUBLE_INT + # IR.inject!( + # "MPI_DOUBLE_INT", "llvm.mlir.global constant @MPI_DOUBLE_INT() : !llvm.ptr" + # ) + # return "MPI_DOUBLE_INT" + # elseif datatype == MPI.LONG_DOUBLE_INT + # IR.inject!( + # "MPI_LONG_DOUBLE_INT", + # "llvm.mlir.global constant @MPI_LONG_DOUBLE_INT() : !llvm.ptr", + # ) + # return "MPI_LONG_DOUBLE_INT" + # elseif datatype == MPI.LONG_INT + # IR.inject!("MPI_LONG_INT", "llvm.mlir.global constant @MPI_LONG_INT() : !llvm.ptr") + # return "MPI_LONG_INT" + # elseif datatype == MPI.SHORT_INT + # IR.inject!( + # "MPI_SHORT_INT", "llvm.mlir.global constant @MPI_SHORT_INT() : !llvm.ptr" + # ) + # return "MPI_SHORT_INT" + # elseif datatype == MPI.UB + # IR.inject!("MPI_UB", "llvm.mlir.global constant @MPI_UB() : !llvm.ptr") + # return "MPI_UB" + # elseif datatype == MPI.LB + # IR.inject!("MPI_LB", "llvm.mlir.global constant @MPI_LB() : !llvm.ptr") + # return "MPI_LB" elseif datatype == MPI.WCHAR IR.inject!("MPI_WCHAR", "llvm.mlir.global constant @MPI_WCHAR() : !llvm.ptr") return "MPI_WCHAR" From 7eb4107ae2434a9f8abf19c9a46b2c47f14fa479 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 12:29:01 -0500 Subject: [PATCH 40/97] small fixes --- ext/ReactantMPIExt/Ops.jl | 11 +++++++---- ext/ReactantMPIExt/ReactantMPIExt.jl | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index e29178b7cd..573deccf37 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -4,7 +4,7 @@ using Reactant: MLIR using Reactant.MLIR: IR using Reactant.MLIR.IR: @mlir_str using Reactant.MLIR.Dialects: mpi, func, llvm, enzymexla -using Reactant.Ops: mlir_stacktrace +using Reactant.Ops: mlir_stacktrace, mlir_type using ..ReactantMPIExt: TracedRequest using MPI: MPI @@ -328,7 +328,7 @@ function send( %count = llvm.load %count_ptr : !llvm.ptr -> i32 %dest = llvm.load %dest_ptr : !llvm.ptr -> i32 %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 - %status = llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr) -> (i32) + %errcode = llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> (i32) func.return } """) @@ -336,12 +336,15 @@ function send( count = Reactant.Ops.constant(length(buf)) - return enzymexla.jit_call( + enzymexla.jit_call( IR.Value[buf.mlir_data, count.mlir_data, tag.mlir_data, dest.mlir_data]; fn=sym_attr, - result_0=IR.Type[], + result_0=IR.Type[mlir_type(buf), mlir_type(count), mlir_type(tag), mlir_type(dest)], + output_operand_aliases=IR.Attribute(IR.Attribute[]), location, ) + + return nothing end # TODO need c-function for creating MLIR `mpi.request` type? 
diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 49e00e4a07..318107deb1 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -211,7 +211,7 @@ function __init__() ] value = getproperty(MPI.API, name) if value isa Base.RefValue - value = value[] # TODO we need to convert this to Ptr{Cvoid} because that's what the symbol table stores + value = value[] end value = convert(Int, value) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Int)::Cvoid From 19c0eca670c9a89274841ba5366e83adf5a0d537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 11 Mar 2025 12:29:17 -0500 Subject: [PATCH 41/97] Implement `MPI.Recv!` --- ext/ReactantMPIExt/Ops.jl | 62 ++++++++++++++++++++++++++++----- ext/ReactantMPIExt/Overrides.jl | 36 ++++++++++++------- 2 files changed, 78 insertions(+), 20 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 573deccf37..48197e7ece 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -369,20 +369,66 @@ function isend( end function recv!( - ref::TracedRArray, + recvbuf::TracedRArray, tag::TracedRNumber, src::TracedRNumber; location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__), ) - # return mpi.recv(ref.mlir_data, tag.mlir_data, src.mlir_data; location) + T = Reactant.unwrapped_eltype(recvbuf) + mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) - # TODO emit constant for size and datatype, and pass as args - inputs = IR.Value[ref.mlir_data, tag.mlir_data, src.mlir_data] - sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Recv") - rettype = IR.Type[] + sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) - IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) - return ref + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_STATUS_IGNORE", "llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr" + ) + IR.inject!( + "MPI_Recv", + "llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32", + ) + + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr + %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %source = llvm.load %source_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + %errcode = llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + count = Reactant.Ops.constant(fill(length(recvbuf))) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) + + res = IR.result( + enzymexla.jit_call( + IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data]; + fn=sym_attr, + result_0=[mlir_type(recvbuf)], + output_operand_aliases, + location, + ), + 1, + ) + + recvbuf.mlir_data = res + + return recvbuf end # TODO need c-function for creating MLIR `mpi.request` type? 
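For reference: the single `output_operand_alias` built above (output index 0 aliasing operand index 0) tells XLA that the `jit_call` result reuses the receive buffer operand, so `MPI_Recv` writes in place and the result can be swapped back into `recvbuf.mlir_data` without a copy. A minimal two-rank usage sketch of the override this enables, with the traced receive on rank 0 and a plain MPI.jl send on rank 1 (buffer size, tag and element type are placeholders; this mirrors the standalone test script added later in the series):

    using MPI, Reactant
    MPI.Init()
    comm = MPI.COMM_WORLD
    if MPI.Comm_rank(comm) == 0
        recvbuf = Reactant.to_rarray(zeros(Int32, 8))
        # source = 1, tag = 43; dispatches to the MPI_Recv wrapper above during tracing
        @jit MPI.Recv!(recvbuf, 1, 43, comm)
        println(recvbuf == ones(Int32, 8))
    else
        # ordinary MPI.jl send from rank 1
        MPI.Send(ones(Int32, 8), 0, 43, comm)
    end
    MPI.Finalize()
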
diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 7b7f59e122..16b8ee332d 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -75,22 +75,34 @@ function MPI.Isend( return req end -# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` +function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm) + tag = Reactant.Ops.constant(tag) + source = Reactant.Ops.constant(source) + return MPI.Recv!(buf, source, tag, comm) +end + function MPI.Recv!( - recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, status + recvbuf::TracedRArray, + source::Integer, + tag::Integer, + comm::MPI.Comm, + ::Type{MPI.API.MPI_Status}, ) - @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - @assert isnothing(status) "Status not supported yet" - - tag = if !(tag isa TracedRNumber) - Reactant.Ops.constant(tag) - end + return MPI.Recv!(recvbuf, source, tag, comm) +end - source = if !(source isa TracedRNumber) - Reactant.Ops.constant(source) - end +function MPI.Recv!( + recvbuf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm, ::Nothing +) + return MPI.Recv!(recvbuf, source, tag, comm) +end - return Ops.recv(recvbuf, tag, source) +# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` +function MPI.Recv!( + recvbuf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm +) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.recv!(recvbuf, tag, source) end # TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` From 946339c297a2e8505e92a85abc680f44dee36892 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 16 Mar 2025 09:01:28 +0100 Subject: [PATCH 42/97] Test MPI --- test/Project.toml | 2 ++ test/integration/mpi.jl | 20 ++++++++++++++++++++ test/runtests.jl | 4 ++++ 3 files changed, 26 insertions(+) create mode 100644 test/integration/mpi.jl diff --git a/test/Project.toml b/test/Project.toml index 8d46640284..9109d4a3d7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" @@ -48,6 +49,7 @@ LinearAlgebra = "1.10" Lux = "1.4.1" LuxLib = "1.3" MLUtils = "0.4.4" +MPI = "0.20" NNlib = "0.9.26" OffsetArrays = "1" OneHotArrays = "0.2.6" diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl new file mode 100644 index 0000000000..60228eb155 --- /dev/null +++ b/test/integration/mpi.jl @@ -0,0 +1,20 @@ +using Test, MPI, Reactant + +@testset "Comm_rank" begin + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + @test rank == @jit MPI.Comm_rank(comm) +end + +@testset "Comm_size" begin + comm = MPI.COMM_WORLD + nranks = MPI.Comm_size(comm) + @test nranks == @jit MPI.Comm_size(comm) +end + +@testset "Allreduce" begin + comm = MPI.COMM_WORLD + x = ConcreteRArray(fill(1)) + nranks = MPI.Comm_size(comm) + @test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD) +end diff --git a/test/runtests.jl b/test/runtests.jl index 1e848d1715..3266037107 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -77,6 
+77,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "SpecialFunctions" include("integration/special_functions.jl") @safetestset "Random" include("integration/random.jl") @safetestset "Python" include("integration/python.jl") + @safetestset "MPI" begin + nranks = 2 + run(`$(mpiexec()) -nranks $n $(Base.julia_cmd()) integration/mpi.jl`) + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From 42bf6b26df34852d1589a70032c22e6fd276ba32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sun, 16 Mar 2025 09:08:57 +0100 Subject: [PATCH 43/97] Update src/mlir/IR/Operation.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/mlir/IR/Operation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index d1be396db7..95d46e33d1 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -26,7 +26,7 @@ Parses an operation from the string and transfers ownership to the caller. function Base.parse( ::Core.Type{Operation}, code; - verify::Bool = false, + verify::Bool=false, context::Context=context(), block=Block(), location::Location=Location(), From fad18452b7cd76583081d5e26a93ea7c839dd3c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 16 Mar 2025 09:22:32 +0100 Subject: [PATCH 44/97] Fix `mpiexec` symbol import --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 17b968168a..c9081a283c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,6 +75,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Random" include("integration/random.jl") @safetestset "Python" include("integration/python.jl") @safetestset "MPI" begin + using MPI nranks = 2 run(`$(mpiexec()) -nranks $n $(Base.julia_cmd()) integration/mpi.jl`) end From 0359559e229c98775c818959b1370d9e844465e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 16 Mar 2025 09:32:08 +0100 Subject: [PATCH 45/97] Fix typo --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index c9081a283c..2302d91618 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -77,7 +77,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "MPI" begin using MPI nranks = 2 - run(`$(mpiexec()) -nranks $n $(Base.julia_cmd()) integration/mpi.jl`) + run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) end end From dd327ecdb817eb9dfc6c33754d30c7385768fc2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 16 Mar 2025 09:51:11 +0100 Subject: [PATCH 46/97] Init and Finalize on MPI tests --- test/integration/mpi.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index 60228eb155..8fef6a8f5a 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -1,5 +1,7 @@ using Test, MPI, Reactant +MPI.Init() + @testset "Comm_rank" begin comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) @@ -18,3 +20,5 @@ end nranks = MPI.Comm_size(comm) @test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD) end + +MPI.Finalize() From 113b30fcc0612545d14975863167262f44b7a421 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 18 May 2025 17:27:59 -0500 Subject: [PATCH 47/97] Fix changes introduces in "feat: IR inject functions (#1217)" --- src/mlir/IR/IR.jl | 45 ---------------------------------------- src/mlir/IR/Operation.jl | 24 --------------------- 2 files changed, 69 deletions(-) diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 96732cfec5..3d429aeaf5 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -170,49 +170,4 @@ function tryinjectop!(sym_name, code; mod=IR.mmodule(), location=Location()) end end -function tryinject!(sym_name, code; verify=false, mod=IR.mmodule(), location=Location()) - fn = lookup(SymbolTable(Operation(mod)), sym_name) - - if isnothing(fn) - ctx = IR.context() - block = body(mod) - res = @ccall API.mlir_c.mlirOperationInject( - ctx::API.MlirContext, - block::API.MlirBlock, - code::API.MlirStringRef, - location::API.MlirLocation, - verify::Bool, - )::Bool - return res - else - return true - end -end - -function inject!(sym_name, code; kwargs...) - success = tryinject!(sym_name, code; kwargs...) - if !success - throw(ErrorException("Failed injecting MLIR to top-level block")) - end -end - -function tryinjectop!(sym_name, code; mod=IR.mmodule(), location=Location()) - fn = lookup(SymbolTable(Operation(mod)), sym_name) - - if isnothing(fn) - top_level_block = body(mod) - op = parse(Operation, code; block=top_level_block, location) - - # using `collect` because if we remove the op, then the `OperationIterator` state is broken and skips ops - # for op in collect(OperationIterator(code)) - # rmfromparent!(op) - # push!(top_level_block, op) - # end - - return op - else - return nothing - end -end - end # module IR diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 82c9479f98..e9f763ac58 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -18,30 +18,6 @@ function Base.unsafe_convert(::Core.Type{API.MlirOperation}, operation::Operatio end Base.:(==)(op::Operation, other::Operation) = API.mlirOperationEqual(op, other) -""" - parse(::Type{Operation}, code; context=context()) - -Parses an operation from the string and transfers ownership to the caller. -""" -function Base.parse( - ::Core.Type{Operation}, - code; - verify::Bool=false, - context::Context=context(), - block=Block(), - location::Location=Location(), -) - return Operation( - @ccall API.mlir_c.mlirOperationParse( - context::API.MlirContext, - block::API.MlirBlock, - code::API.MlirStringRef, - location::API.MlirLocation, - verify::Bool, - )::API.MlirOperation - ) -end - """ copy(op) From 15335a9a8581c219c63f8f076d303bbbe14a4279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 12 Jul 2025 13:58:01 +0200 Subject: [PATCH 48/97] return errcode for `send` and `recv!` --- ext/ReactantMPIExt/Ops.jl | 71 ++++++++++++++++++++++----------- ext/ReactantMPIExt/Overrides.jl | 2 + 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 48197e7ece..31ead8682a 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -322,29 +322,41 @@ function send( #! 
format: off # TODO IR.inject!(sym_name, """ - func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + func.func @$sym_name(%errcode : !llvm.ptr, %buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr %count = llvm.load %count_ptr : !llvm.ptr -> i32 %dest = llvm.load %dest_ptr : !llvm.ptr -> i32 %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 - %errcode = llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> (i32) + %res = llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> (i32) + llvm.store %res, %errcode : i32, !llvm.ptr func.return } """) #! format: on - count = Reactant.Ops.constant(length(buf)) + count = Reactant.Ops.constant(Int32(length(buf))) + errcode = Reactant.Ops.constant(fill(Cint(0))) - enzymexla.jit_call( - IR.Value[buf.mlir_data, count.mlir_data, tag.mlir_data, dest.mlir_data]; + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) + + ret = enzymexla.jit_call( + IR.Value[ + errcode.mlir_data, buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data + ]; fn=sym_attr, - result_0=IR.Type[mlir_type(buf), mlir_type(count), mlir_type(tag), mlir_type(dest)], - output_operand_aliases=IR.Attribute(IR.Attribute[]), + result_0=IR.Type[mlir_type(errcode)], + output_operand_aliases=output_operand_aliases, location, ) - - return nothing + errcode.mlir_data = IR.result(ret) + return errcode end # TODO need c-function for creating MLIR `mpi.request` type? @@ -392,43 +404,54 @@ function recv!( #! format: off IR.inject!(sym_name, """ - func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + func.func @$sym_name(%errcode : !llvm.ptr, %buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr %count = llvm.load %count_ptr : !llvm.ptr -> i32 %source = llvm.load %source_ptr : !llvm.ptr -> i32 %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 - %errcode = llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + %res = llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + llvm.store %res, %errcode : i32, !llvm.ptr func.return } """) #! 
format: on - count = Reactant.Ops.constant(fill(length(recvbuf))) + count = Reactant.Ops.constant(Int32(length(recvbuf))) + errcode = Reactant.Ops.constant(fill(Cint(0))) output_operand_aliases = IR.Attribute([ IR.Attribute( MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + MLIR.IR.context(), 1, Ref{Int64}(0), 0, 0, C_NULL + ), + ), + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 1, Ref{Int64}(1), 1, 0, C_NULL ), ), ]) - res = IR.result( - enzymexla.jit_call( - IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data]; - fn=sym_attr, - result_0=[mlir_type(recvbuf)], - output_operand_aliases, - location, - ), - 1, + ret = enzymexla.jit_call( + IR.Value[ + errcode.mlir_data, + recvbuf.mlir_data, + count.mlir_data, + src.mlir_data, + tag.mlir_data, + ]; + fn=sym_attr, + result_0=[mlir_type(errcode), mlir_type(recvbuf)], + output_operand_aliases, + location, ) - recvbuf.mlir_data = res + errcode.mlir_data = IR.result(ret, 1) + recvbuf.mlir_data = IR.result(ret, 2) - return recvbuf + return errcode, recvbuf end # TODO need c-function for creating MLIR `mpi.request` type? diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 16b8ee332d..b208550dea 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -39,6 +39,8 @@ function MPI.Send(buf::TracedRArray, dest::Integer, tag::Integer, comm::MPI.Comm end # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` +# NOTE unlike MPI.jl's `MPI.Send`, we return the errcode to generate the data dep +# that prevents it from being optimized away function MPI.Send( buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm ) From 3391df8a6aa5d6980b9befd324c01ca28a6c85b1 Mon Sep 17 00:00:00 2001 From: romanlee Date: Wed, 16 Jul 2025 14:06:43 -0700 Subject: [PATCH 49/97] Add tests for Send and Recv! --- test/integration/mpi.jl | 60 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index 8fef6a8f5a..effceda98d 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -21,4 +21,64 @@ end @test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD) end +@testset "Send, Recv!" begin + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + nranks = MPI.Comm_size(comm) + + if nranks < 2 + @warn "Need at least 2 MPI processes for send tests. Skipping." + return + end + + # test MPI.jl Send / Reactant Recv + @testset "MPI.jl Send / Reactant Recv!" begin + send_buf = fill(1) + tag = 43 + if rank == 0 + MPI.Send(send_buf, comm; dest=1, tag=tag) + @test true + elseif rank == 1 + recv_buf = ConcreteRArray(fill(12)) + source = 0 + @jit MPI.Recv!(recv_buf, source, tag, comm) + @test recv_buf == send_buf + end + end + + # test Reactant Send / MPI.jl Recv + @testset "Reactant Send / MPI.jl Recv!" begin + send_buf = ConcreteRArray(fill(1)) + tag = 43 + if rank == 0 + dest = 1 + @jit MPI.Send(send_buf, dest, tag, comm) + @test true + elseif rank == 1 + recv_buf = fill(12) + MPI.Recv!(recv_buf, comm; source=0, tag=tag) + @test recv_buf == send_buf + end + end + + # test Reactant Send/Recv + @testset "Reactant Send / Recv!" 
begin + send_buf = ConcreteRArray(fill(1)) + tag = 43 + if rank == 0 + # Send: pass on cpu, pass on gpu + dest = 1 + @jit MPI.Send(send_buf, dest, tag, comm) + @test true # Send completed + elseif rank == 1 + # hang on cpu + # segfault on gpu upon trying to reference res + recv_buf = ConcreteRArray(fill(12)) + src = 0 + @jit MPI.Recv!(recv_buf, src, tag, comm) + @test recv_buf == send_buf + end + end +end + MPI.Finalize() From 89f4ad4a734fedff6f3cdd78b544471a22f59196 Mon Sep 17 00:00:00 2001 From: romanlee Date: Fri, 18 Jul 2025 15:28:55 -0700 Subject: [PATCH 50/97] Chipping away at isend --- ext/ReactantMPIExt/Ops.jl | 80 ++++++++++++++++++++++++++++----- ext/ReactantMPIExt/Overrides.jl | 70 ++++++++++++++++++++--------- 2 files changed, 118 insertions(+), 32 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 31ead8682a..ff45cdbdf7 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -320,7 +320,8 @@ function send( ) #! format: off - # TODO + # int MPI_Send(const void* buf, int count, MPI_Datatype datatype, + # int dest, int tag, MPI_Comm comm) IR.inject!(sym_name, """ func.func @$sym_name(%errcode : !llvm.ptr, %buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr @@ -363,21 +364,78 @@ end function isend( buf::TracedRArray, tag::TracedRNumber, - dest::TracedRNumber; + dest::TracedRNumber, + req::TracedRequest; # TODO ROMAN shouldn't we pass this in?? location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), ) - # return TracedRequest( - # IR.result(mpi.isend(buf.mlir_data, tag.mlir_data, dest.mlir_data; location)) - # ) + T = Reactant.unwrapped_eltype(buf) + mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) - # TODO emit constant for size and datatype, and pass as args - inputs = IR.Value[buf.mlir_data, tag.mlir_data, dest.mlir_data] - sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Isend") - rettype = IR.Type[] # TODO return MPI_Request -> use i32 or opaque? + sym_name = "enzymexla_wrapper_MPI_Isend_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) - return TracedRequest( - IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Isend", + "llvm.func @MPI_Isend(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32", ) + + #! format: off + # int MPI_Isend(const void* buf, int count, MPI_Datatype datatype, + # int dest, int tag, MPI_Comm comm, MPI_Request* request) + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr, %req_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %dest = llvm.load %dest_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + %res = llvm.call @MPI_Isend(%buf, %count, %datatype, %dest, %tag, %comm, %req_ptr) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + count = Reactant.Ops.constant(Int32(length(buf))) + + # TODO ROMAN need to use output_operand_aliases to get the reutnr values? 
+ # how the hell does this thing work + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) + + ret = enzymexla.jit_call( + IR.Value[ + buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data, req.mlir_data + ]; + fn=sym_attr, + result_0=IR.Type[mlir_type(req)], # TODO ROMAN: need to define a function mlir_type(::TracedRequest)? + output_operand_aliases=output_operand_aliases, + location, + ) + + # TODO ROMAN how to return the request? + return TracedRequest( IR.result(ret) ) + + + # # Sergio's stuff + # # return TracedRequest( + # # IR.result(mpi.isend(buf.mlir_data, tag.mlir_data, dest.mlir_data; location)) + # # ) + + # # TODO emit constant for size and datatype, and pass as args + # inputs = IR.Value[buf.mlir_data, tag.mlir_data, dest.mlir_data] + # sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Isend") + # rettype = IR.Type[] # TODO return MPI_Request -> use i32 or opaque? + + # return TracedRequest( + # IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) + # ) end function recv!( diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index b208550dea..97b31bdd55 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -32,7 +32,12 @@ function MPI.Wait(req::TracedRequest) end # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` -function MPI.Send(buf::TracedRArray, dest::Integer, tag::Integer, comm::MPI.Comm) +function MPI.Send( + buf::TracedRArray, + dest::Integer, + tag::Integer, + comm::MPI.Comm +) tag = Reactant.Ops.constant(tag) dest = Reactant.Ops.constant(dest) return MPI.Send(buf, dest, tag, comm) @@ -42,40 +47,63 @@ end # NOTE unlike MPI.jl's `MPI.Send`, we return the errcode to generate the data dep # that prevents it from being optimized away function MPI.Send( - buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm + buf::TracedRArray, + dest::TracedRNumber, + tag::TracedRNumber, + comm::MPI.Comm ) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" return Ops.send(buf, tag, dest) end +# TODO should we error if other `AbstractRequest` types are passed in? +function MPI.Isend( + buf::TracedRArray, + dest::Integer, + tag::Integer, + comm::MPI.Comm, + req::TracedRequest +) + tag = Reactant.Ops.constant(tag) + dest = Reactant.Ops.constant(dest) + + gen_req = MPI.Isend(buf, dest, tag, comm) + req.mlir_data = gen_req.mlir_data + return req +end + # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` function MPI.Isend( buf::TracedRArray, - dest::Union{T,TracedRNumber{T}}, - tag::Union{T,TracedRNumber{T}}, + dest::TracedRNumber, + tag::TracedRNumber, comm::MPI.Comm, -) where {T<:Integer} +) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - tag = if !(tag isa TracedRNumber) - Reactant.Ops.constant(tag) - end - - dest = if !(dest isa TracedRNumber) - Reactant.Ops.constant(dest) - end - return Ops.isend(buf, tag, dest) end -# TODO should we error if other `AbstractRequest` types are passed in? -function MPI.Isend( - buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm, req::TracedRequest -) - gen_req = MPI.Isend(buf, dest, tag, comm) - req.mlir_data = gen_req.mlir_data - return req -end +# TODO ROMAN do we want to use this signature, or the one in MPI.Send? Either way, they should be the same +# # TODO use `make_tracer` to linearize arbitrary types? 
check out `MPI.Buffer` +# function MPI.Isend( +# buf::TracedRArray, +# dest::Union{T,TracedRNumber{T}}, +# tag::Union{T,TracedRNumber{T}}, +# comm::MPI.Comm, +# ) where {T<:Integer} +# @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + +# tag = if !(tag isa TracedRNumber) +# Reactant.Ops.constant(tag) +# end + +# dest = if !(dest isa TracedRNumber) +# Reactant.Ops.constant(dest) +# end + +# return Ops.isend(buf, tag, dest) +# end function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm) tag = Reactant.Ops.constant(tag) From b0cdb4208759875f560ff0ce4fe45c74b2311c4c Mon Sep 17 00:00:00 2001 From: romanlee Date: Mon, 21 Jul 2025 14:41:12 -0700 Subject: [PATCH 51/97] Finish off Ops.isend. Could be working now --- ext/ReactantMPIExt/Ops.jl | 33 ++++++++------------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index ff45cdbdf7..4517751518 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -364,8 +364,7 @@ end function isend( buf::TracedRArray, tag::TracedRNumber, - dest::TracedRNumber, - req::TracedRequest; # TODO ROMAN shouldn't we pass this in?? + dest::TracedRNumber; location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), ) T = Reactant.unwrapped_eltype(buf) @@ -381,9 +380,9 @@ function isend( "llvm.func @MPI_Isend(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32", ) - #! format: off # int MPI_Isend(const void* buf, int count, MPI_Datatype datatype, # int dest, int tag, MPI_Comm comm, MPI_Request* request) + #! format: off IR.inject!(sym_name, """ func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr, %req_ptr : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr @@ -398,44 +397,28 @@ function isend( #! format: on count = Reactant.Ops.constant(Int32(length(buf))) + request = Reactant.Ops.constant(Int64(-1)) - # TODO ROMAN need to use output_operand_aliases to get the reutnr values? - # how the hell does this thing work output_operand_aliases = IR.Attribute([ IR.Attribute( MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + MLIR.IR.context(), 0, C_NULL, 4, 0, C_NULL ), ), ]) ret = enzymexla.jit_call( IR.Value[ - buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data, req.mlir_data + buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data, request.mlir_data ]; fn=sym_attr, - result_0=IR.Type[mlir_type(req)], # TODO ROMAN: need to define a function mlir_type(::TracedRequest)? + result_0=IR.Type[mlir_type(request)], output_operand_aliases=output_operand_aliases, location, ) - # TODO ROMAN how to return the request? - return TracedRequest( IR.result(ret) ) - - - # # Sergio's stuff - # # return TracedRequest( - # # IR.result(mpi.isend(buf.mlir_data, tag.mlir_data, dest.mlir_data; location)) - # # ) - - # # TODO emit constant for size and datatype, and pass as args - # inputs = IR.Value[buf.mlir_data, tag.mlir_data, dest.mlir_data] - # sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Isend") - # rettype = IR.Type[] # TODO return MPI_Request -> use i32 or opaque? 
- - # return TracedRequest( - # IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) - # ) + request.mlir_data = IR.result(ret) + return request end function recv!( From 3c132d4bb60bec1d7770e977fc2bdcd339bfc084 Mon Sep 17 00:00:00 2001 From: romanlee Date: Mon, 21 Jul 2025 14:51:42 -0700 Subject: [PATCH 52/97] Add default constructor for TracedRequest --- ext/ReactantMPIExt/ReactantMPIExt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 318107deb1..5428b7dc5f 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -222,6 +222,8 @@ struct TracedRequest <: MPI.AbstractRequest mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} end +TracedRequest() = TracedRequest(nothing) + include("Ops.jl") include("Overrides.jl") From 32cd9224e34a6392641f78aeb1317f1cf0c7dc7e Mon Sep 17 00:00:00 2001 From: romanlee Date: Mon, 21 Jul 2025 16:16:59 -0700 Subject: [PATCH 53/97] Clean up mpi.jl tests --- test/integration/mpi.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index effceda98d..e467151146 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -26,18 +26,12 @@ end rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) - if nranks < 2 - @warn "Need at least 2 MPI processes for send tests. Skipping." - return - end - # test MPI.jl Send / Reactant Recv @testset "MPI.jl Send / Reactant Recv!" begin send_buf = fill(1) tag = 43 if rank == 0 MPI.Send(send_buf, comm; dest=1, tag=tag) - @test true elseif rank == 1 recv_buf = ConcreteRArray(fill(12)) source = 0 @@ -53,7 +47,6 @@ end if rank == 0 dest = 1 @jit MPI.Send(send_buf, dest, tag, comm) - @test true elseif rank == 1 recv_buf = fill(12) MPI.Recv!(recv_buf, comm; source=0, tag=tag) @@ -69,7 +62,6 @@ end # Send: pass on cpu, pass on gpu dest = 1 @jit MPI.Send(send_buf, dest, tag, comm) - @test true # Send completed elseif rank == 1 # hang on cpu # segfault on gpu upon trying to reference res From 1bc9ddc21246ce7d65c521e58c4cb462bdf45a65 Mon Sep 17 00:00:00 2001 From: romanlee Date: Tue, 22 Jul 2025 16:43:24 -0700 Subject: [PATCH 54/97] TracedRequest must be mutable struct? Also make TracedRequest optional arg in MPI.Isend --- ext/ReactantMPIExt/Ops.jl | 1 + ext/ReactantMPIExt/Overrides.jl | 10 +++++----- ext/ReactantMPIExt/ReactantMPIExt.jl | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 4517751518..b2abe61436 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -417,6 +417,7 @@ function isend( location, ) + # return TracedRNumber request.mlir_data = IR.result(ret) return request end diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 97b31bdd55..c096e54f16 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -62,14 +62,14 @@ function MPI.Isend( dest::Integer, tag::Integer, comm::MPI.Comm, - req::TracedRequest + request::TracedRequest=TracedRequest() ) tag = Reactant.Ops.constant(tag) dest = Reactant.Ops.constant(dest) - gen_req = MPI.Isend(buf, dest, tag, comm) - req.mlir_data = gen_req.mlir_data - return req + gen_request = MPI.Isend(buf, dest, tag, comm) + request.mlir_data = gen_request.mlir_data + return request end # TODO use `make_tracer` to linearize arbitrary types? 
check out `MPI.Buffer` @@ -84,7 +84,7 @@ function MPI.Isend( return Ops.isend(buf, tag, dest) end -# TODO ROMAN do we want to use this signature, or the one in MPI.Send? Either way, they should be the same +# TODO possible to use this signature? As is, ambiguous with the ones defined by MPI.jl # # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` # function MPI.Isend( # buf::TracedRArray, diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 5428b7dc5f..dc8725b698 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -218,7 +218,7 @@ function __init__() end end -struct TracedRequest <: MPI.AbstractRequest +mutable struct TracedRequest <: MPI.AbstractRequest mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} end From 1218d5c3a6544ebf4f4b6f81cce72d589d496f2c Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 14 Aug 2025 16:24:54 -0700 Subject: [PATCH 55/97] Scratching together an implementation of make_tracer for TracedRequest --- ext/ReactantMPIExt/ReactantMPIExt.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index dc8725b698..e62519be7f 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -218,12 +218,40 @@ function __init__() end end + mutable struct TracedRequest <: MPI.AbstractRequest mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} end + TracedRequest() = TracedRequest(nothing) + +Base.@nospecializeinfer function Reactant.make_tracer( + seen, + @nospecialize(prev::TracedRequest), + @nospecialize(path), + mode; + tobatch=nothing, + toscalar=false, + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + kwargs..., +) + if mode == Reactant.TracedToConcrete + haskey(seen, prev) && return seen[prev]::MPI.Request + if !Sharding.is_sharded(sharding) + res = MPI.Request + else + error("you probably shouldnt be using sharding and mpi...") + end + seen[prev] = res + return res + end + throw("Trace mode $mode not implemented") +end + + include("Ops.jl") include("Overrides.jl") From 626566b4f0d6dff25970f86687101383bb66e897 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 14 Aug 2025 17:25:37 -0700 Subject: [PATCH 56/97] Additions to TracedRequest and associated functions/uses --- ext/ReactantMPIExt/Overrides.jl | 2 +- ext/ReactantMPIExt/ReactantMPIExt.jl | 25 +++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index c096e54f16..cffa9994d9 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -62,7 +62,7 @@ function MPI.Isend( dest::Integer, tag::Integer, comm::MPI.Comm, - request::TracedRequest=TracedRequest() + request::TracedRequest=TracedRequest((), nothing) ) tag = Reactant.Ops.constant(tag) dest = Reactant.Ops.constant(dest) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index e62519be7f..fa7af86d5d 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -220,11 +220,24 @@ end mutable struct TracedRequest <: MPI.AbstractRequest + paths::Tuple mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} -end + function TracedRequest( + paths::Tuple, mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} + ) + if !isnothing(mlir_data) + @assert size(Reactant.MLIR.IR.type(mlir_data)) == () + end + return 
new(paths, mlir_data) + end +end -TracedRequest() = TracedRequest(nothing) +# # presumably gonna have to write these at som epoint too +# get_mlir_data(x::TracedRequest) = x.mlir_data +# set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) +get_paths(x::TracedRequest) = x.paths +set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) Base.@nospecializeinfer function Reactant.make_tracer( @@ -238,6 +251,14 @@ Base.@nospecializeinfer function Reactant.make_tracer( @nospecialize(runtime = nothing), kwargs..., ) + println("IN make_tracer{TracedRequest}") + if mode == Reactant.NoStopTracedTrack + set_paths!(prev, (get_paths(prev)..., path)) + if !haskey(seen, prev) + seen[prev] = prev # don't return! + end + return prev + end if mode == Reactant.TracedToConcrete haskey(seen, prev) && return seen[prev]::MPI.Request if !Sharding.is_sharded(sharding) From 27031380948aa05e0e87aa339e80b82c08a8c5ae Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Mon, 18 Aug 2025 15:24:03 -0700 Subject: [PATCH 57/97] compile_xla/mlir might be working properly with TracedRequest now --- ext/ReactantMPIExt/ReactantMPIExt.jl | 27 ++++++++++++++++++++------- src/Reactant.jl | 3 ++- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index fa7af86d5d..68153e4f20 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -36,6 +36,10 @@ function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) end function __init__() + # TODO improve this, temporary hack + # when you fix it, remember to possibly make TracedType const again + Reactant.TracedType = Union{Reactant.TracedRArray,Reactant.TracedRNumber,Reactant.MissingTracedValue,TracedRequest} + # TODO maybe it's more efficient if we use `RTLD_NOW` instead of `RTLD_LAZY`? libmpi_handle = Libdl.dlopen(MPI.API.libmpi, RTLD_LAZY | RTLD_GLOBAL) @@ -233,11 +237,21 @@ mutable struct TracedRequest <: MPI.AbstractRequest end end -# # presumably gonna have to write these at som epoint too -# get_mlir_data(x::TracedRequest) = x.mlir_data -# set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) -get_paths(x::TracedRequest) = x.paths -set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) +function Base.show(io::IOty, X::TracedRequest) where {IOty<:Union{IO,IOContext}} + return print(io, "TracedRequest(", X.paths, ")") +end + +Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data +# Reactant.TracedUtils.set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) # guess this one was never needed...? Maybe it happens in create_result +Reactant.TracedUtils.get_paths(x::TracedRequest) = x.paths +Reactant.TracedUtils.set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) + +# TODO not sure how to implement this for TracedRequest +# probably just want to hardcode the types and dims? 
+function Reactant.Ops.mlir_type(x::TracedRequest)::MLIR.IR.Type + # return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x))) + return MLIR.IR.TensorType(collect(Int, ()), MLIR.IR.Type(Int64)) +end Base.@nospecializeinfer function Reactant.make_tracer( @@ -251,9 +265,8 @@ Base.@nospecializeinfer function Reactant.make_tracer( @nospecialize(runtime = nothing), kwargs..., ) - println("IN make_tracer{TracedRequest}") if mode == Reactant.NoStopTracedTrack - set_paths!(prev, (get_paths(prev)..., path)) + Reactant.TracedUtils.set_paths!(prev, (Reactant.TracedUtils.get_paths(prev)..., path)) if !haskey(seen, prev) seen[prev] = prev # don't return! end diff --git a/src/Reactant.jl b/src/Reactant.jl index 02893f9516..6c559a8916 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -181,7 +181,8 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") -const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} +# const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} +TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} include("ControlFlow.jl") include("Tracing.jl") From 0bac4dca94f49b227db7e7b767c097eaa09da5ae Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Wed, 27 Aug 2025 17:36:15 -0700 Subject: [PATCH 58/97] Add create_result for MPI.Request, fix make_tracer --- ext/ReactantMPIExt/ReactantMPIExt.jl | 48 +++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 68153e4f20..24e7465e5d 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -242,7 +242,7 @@ function Base.show(io::IOty, X::TracedRequest) where {IOty<:Union{IO,IOContext}} end Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data -# Reactant.TracedUtils.set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) # guess this one was never needed...? Maybe it happens in create_result +# Reactant.TracedUtils.set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) # TODO maybe we should be using this in Isend to set the mlir data... 
Reactant.TracedUtils.get_paths(x::TracedRequest) = x.paths Reactant.TracedUtils.set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) @@ -253,7 +253,6 @@ function Reactant.Ops.mlir_type(x::TracedRequest)::MLIR.IR.Type return MLIR.IR.TensorType(collect(Int, ()), MLIR.IR.Type(Int64)) end - Base.@nospecializeinfer function Reactant.make_tracer( seen, @nospecialize(prev::TracedRequest), @@ -275,9 +274,9 @@ Base.@nospecializeinfer function Reactant.make_tracer( if mode == Reactant.TracedToConcrete haskey(seen, prev) && return seen[prev]::MPI.Request if !Sharding.is_sharded(sharding) - res = MPI.Request + res = MPI.Request() else - error("you probably shouldnt be using sharding and mpi...") + error("Attempting to use sharding and MPI simultaneously") end seen[prev] = res return res @@ -285,6 +284,47 @@ Base.@nospecializeinfer function Reactant.make_tracer( throw("Trace mode $mode not implemented") end +function Reactant.Compiler.create_result( + tocopy::MPI.Request, + path, + result_stores, + path_to_shard_info, + to_unreshard_results, + unresharded_code::Vector{Expr}, + unresharded_arrays_cache, + used_shardinfo, + result_cache, + var_idx, + resultgen_code, +) + if !haskey(result_cache, tocopy) + sym = Symbol("result", var_idx[]) + var_idx[] += 1 + + @assert haskey(result_stores, path) + restore = result_stores[path] + delete!(result_stores, path) + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + error("Attempting to use sharding and MPI simultaneously") + else + # TODO + # restore = result_buffer1 = linearized_results[1] = result of XLA.executesharded() + # but what is actually returned from XLA.executesharded? + # Same thing as returned from MPI.Isend (ie, TracedRequest)? + result = :(MPI.Request($restore)) + end + push!( + resultgen_code, + quote + $sym = $result + end, + ) + result_cache[tocopy] = sym + end + + return result_cache[tocopy] +end + include("Ops.jl") include("Overrides.jl") From a8643c314211babd5a861da4b9c7c0d0f951bc70 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Fri, 29 Aug 2025 17:06:08 -0700 Subject: [PATCH 59/97] Comment these out, for now we'll assume no Request needs to cross the compile boundary --- ext/ReactantMPIExt/ReactantMPIExt.jl | 164 ++++++++++++++------------- 1 file changed, 83 insertions(+), 81 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 24e7465e5d..f69573a622 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -242,88 +242,90 @@ function Base.show(io::IOty, X::TracedRequest) where {IOty<:Union{IO,IOContext}} end Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data -# Reactant.TracedUtils.set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) # TODO maybe we should be using this in Isend to set the mlir data... -Reactant.TracedUtils.get_paths(x::TracedRequest) = x.paths -Reactant.TracedUtils.set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) +# Reactant.TracedUtils.set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) -# TODO not sure how to implement this for TracedRequest -# probably just want to hardcode the types and dims? 
-function Reactant.Ops.mlir_type(x::TracedRequest)::MLIR.IR.Type - # return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x))) - return MLIR.IR.TensorType(collect(Int, ()), MLIR.IR.Type(Int64)) -end - -Base.@nospecializeinfer function Reactant.make_tracer( - seen, - @nospecialize(prev::TracedRequest), - @nospecialize(path), - mode; - tobatch=nothing, - toscalar=false, - @nospecialize(sharding = Sharding.NoSharding()), - @nospecialize(runtime = nothing), - kwargs..., -) - if mode == Reactant.NoStopTracedTrack - Reactant.TracedUtils.set_paths!(prev, (Reactant.TracedUtils.get_paths(prev)..., path)) - if !haskey(seen, prev) - seen[prev] = prev # don't return! - end - return prev - end - if mode == Reactant.TracedToConcrete - haskey(seen, prev) && return seen[prev]::MPI.Request - if !Sharding.is_sharded(sharding) - res = MPI.Request() - else - error("Attempting to use sharding and MPI simultaneously") - end - seen[prev] = res - return res - end - throw("Trace mode $mode not implemented") -end - -function Reactant.Compiler.create_result( - tocopy::MPI.Request, - path, - result_stores, - path_to_shard_info, - to_unreshard_results, - unresharded_code::Vector{Expr}, - unresharded_arrays_cache, - used_shardinfo, - result_cache, - var_idx, - resultgen_code, -) - if !haskey(result_cache, tocopy) - sym = Symbol("result", var_idx[]) - var_idx[] += 1 - - @assert haskey(result_stores, path) - restore = result_stores[path] - delete!(result_stores, path) - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) - error("Attempting to use sharding and MPI simultaneously") - else - # TODO - # restore = result_buffer1 = linearized_results[1] = result of XLA.executesharded() - # but what is actually returned from XLA.executesharded? - # Same thing as returned from MPI.Isend (ie, TracedRequest)? - result = :(MPI.Request($restore)) - end - push!( - resultgen_code, - quote - $sym = $result - end, - ) - result_cache[tocopy] = sym - end - - return result_cache[tocopy] -end +# # May need these later, but for now we assume that no request needs to pass the compile boundary +# Reactant.TracedUtils.get_paths(x::TracedRequest) = x.paths +# Reactant.TracedUtils.set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) +# +# # TODO not sure how to implement this for TracedRequest +# # probably just want to hardcode the types and dims? +# function Reactant.Ops.mlir_type(x::TracedRequest)::MLIR.IR.Type +# # return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x))) +# return MLIR.IR.TensorType(collect(Int, ()), MLIR.IR.Type(Int64)) +# end +# +# Base.@nospecializeinfer function Reactant.make_tracer( +# seen, +# @nospecialize(prev::TracedRequest), +# @nospecialize(path), +# mode; +# tobatch=nothing, +# toscalar=false, +# @nospecialize(sharding = Sharding.NoSharding()), +# @nospecialize(runtime = nothing), +# kwargs..., +# ) +# if mode == Reactant.NoStopTracedTrack +# Reactant.TracedUtils.set_paths!(prev, (Reactant.TracedUtils.get_paths(prev)..., path)) +# if !haskey(seen, prev) +# seen[prev] = prev # don't return! 
+# end +# return prev +# end +# if mode == Reactant.TracedToConcrete +# haskey(seen, prev) && return seen[prev]::MPI.Request +# if !Sharding.is_sharded(sharding) +# res = MPI.Request() +# else +# error("Attempting to use sharding and MPI simultaneously") +# end +# seen[prev] = res +# return res +# end +# throw("Trace mode $mode not implemented") +# end +# +# function Reactant.Compiler.create_result( +# tocopy::MPI.Request, +# path, +# result_stores, +# path_to_shard_info, +# to_unreshard_results, +# unresharded_code::Vector{Expr}, +# unresharded_arrays_cache, +# used_shardinfo, +# result_cache, +# var_idx, +# resultgen_code, +# ) +# if !haskey(result_cache, tocopy) +# sym = Symbol("result", var_idx[]) +# var_idx[] += 1 +# +# @assert haskey(result_stores, path) +# restore = result_stores[path] +# delete!(result_stores, path) +# if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) +# error("Attempting to use sharding and MPI simultaneously") +# else +# # TODO +# # restore = result_buffer1 = linearized_results[1] = result of XLA.executesharded() +# # but what is actually returned from XLA.executesharded? +# # Same thing as returned from MPI.Isend (ie, TracedRequest)? +# result = :(MPI.Request($restore)) +# end +# push!( +# resultgen_code, +# quote +# $sym = $result +# end, +# ) +# result_cache[tocopy] = sym +# end +# +# return result_cache[tocopy] +# end include("Ops.jl") From 8d2f1688254e12eb7918dc6d0bf90a8c922df192 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Fri, 29 Aug 2025 17:28:44 -0700 Subject: [PATCH 60/97] Add Irecv! and friends. Seems to potentially be working --- ext/ReactantMPIExt/Ops.jl | 87 ++++++++++++++++++++++++++++----- ext/ReactantMPIExt/Overrides.jl | 38 +++++++------- 2 files changed, 95 insertions(+), 30 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index b2abe61436..96750c4ff5 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -496,22 +496,83 @@ function recv!( return errcode, recvbuf end -# TODO need c-function for creating MLIR `mpi.request` type? +# # TODO need c-function for creating MLIR `mpi.request` type? 
+# function irecv!( +# ref::TracedRArray, +# tag::TracedRNumber, +# src::TracedRNumber; +# location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), +# ) +# # return TracedRequest( +# # MLIR.IR.result(mpi.irecv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)) +# # ) +# inputs = IR.Value[ref.mlir_data, tag.mlir_data, src.mlir_data] +# sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Irecv") +# rettype = IR.Type[] +# +# IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) +# return ref +# end + function irecv!( - ref::TracedRArray, + buf::TracedRArray, tag::TracedRNumber, src::TracedRNumber; - location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), + location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), ) - # return TracedRequest( - # MLIR.IR.result(mpi.irecv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)) - # ) - inputs = IR.Value[ref.mlir_data, tag.mlir_data, src.mlir_data] - sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Irecv") - rettype = IR.Type[] - - IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) - return ref + T = Reactant.unwrapped_eltype(buf) + mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) + + sym_name = "enzymexla_wrapper_MPI_Irecv_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Irecv", + "llvm.func @MPI_Irecv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32", + ) + + # int MPI_Irecv(void* buf, int count, MPI_Datatype datatype, + # int source, int tag, MPI_Comm comm, MPI_Request* request) + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %src_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr, %req_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %src = llvm.load %src_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + %res = llvm.call @MPI_Irecv(%buf, %count, %datatype, %src, %tag, %comm, %req_ptr) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + count = Reactant.Ops.constant(Int32(length(buf))) + request = Reactant.Ops.constant(Int64(-1)) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 4, 0, C_NULL + ), + ), + ]) + + ret = enzymexla.jit_call( + IR.Value[ + buf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data, request.mlir_data + ]; + fn=sym_attr, + result_0=IR.Type[mlir_type(request)], + output_operand_aliases=output_operand_aliases, + location, + ) + + # return TracedRNumber + request.mlir_data = IR.result(ret) + return request end function wait( @@ -520,7 +581,7 @@ function wait( sym_name = "enzymexla_wrapper_MPI_Wait" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + # IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") #! 
format: off diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index cffa9994d9..ecc17be7ac 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -64,8 +64,8 @@ function MPI.Isend( comm::MPI.Comm, request::TracedRequest=TracedRequest((), nothing) ) - tag = Reactant.Ops.constant(tag) dest = Reactant.Ops.constant(dest) + tag = Reactant.Ops.constant(tag) gen_request = MPI.Isend(buf, dest, tag, comm) request.mlir_data = gen_request.mlir_data @@ -135,27 +135,31 @@ function MPI.Recv!( return Ops.recv!(recvbuf, tag, source) end -# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` -function MPI.Irecv!(recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm) - @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - - tag = if !(tag isa TracedRNumber) - Reactant.Ops.constant(tag) - end - - source = if !(source isa TracedRNumber) - Reactant.Ops.constant(source) - end +function MPI.Irecv!( + buf::TracedRArray, + source::Integer, + tag::Integer, + comm::MPI.Comm, + request::TracedRequest +) + source = Reactant.Ops.constant(dest) + tag = Reactant.Ops.constant(tag) - return Ops.irecv!(recvbuf, tag, source) + gen_request = MPI.Irecv!(buf, source, tag, comm) + request.mlir_data = gen_request.mlir_data + return request end +# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` function MPI.Irecv!( - recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, req::TracedRequest + buf::TracedRArray, + source::TracedRNumber, + tag::TracedRNumber, + comm::MPI.Comm ) - gen_req = MPI.Irecv!(recvbuf, source, tag, comm) - req.mlir_data = gen_req.mlir_data - return req + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + + return Ops.irecv!(buf, tag, source) end function MPI.Allreduce!(sendbuf::TracedRArray, recvbuf::TracedRArray, op, comm::MPI.Comm) From 56eb8eb8015e6855c1f2a6299f7d996b1841d366 Mon Sep 17 00:00:00 2001 From: romanlee Date: Mon, 1 Sep 2025 13:52:36 -0700 Subject: [PATCH 61/97] Fix overrides Irecv! 
--- ext/ReactantMPIExt/Overrides.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index ecc17be7ac..fe3b26323c 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -140,9 +140,9 @@ function MPI.Irecv!( source::Integer, tag::Integer, comm::MPI.Comm, - request::TracedRequest + request::TracedRequest=TracedRequest((), nothing) ) - source = Reactant.Ops.constant(dest) + source = Reactant.Ops.constant(source) tag = Reactant.Ops.constant(tag) gen_request = MPI.Irecv!(buf, source, tag, comm) From d7a764c6ef816bd7821494868a93aa758a419898 Mon Sep 17 00:00:00 2001 From: romanlee Date: Wed, 3 Sep 2025 12:57:05 -0700 Subject: [PATCH 62/97] Commit debug testing stuff I will of course remove these before any PR is submitted, just useful to have tracked for dev'ing across multiple machines --- roman-temp-debug/2025.09.mpi/Project.toml | 4 + roman-temp-debug/2025.09.mpi/runtests.sh | 22 ++ roman-temp-debug/2025.09.mpi/sergio.jl | 29 +++ roman-temp-debug/2025.09.mpi/setup.sh | 7 + .../2025.09.mpi/test-isend-irecv.jl | 220 ++++++++++++++++++ .../2025.09.mpi/test-isend-irecv_clean.jl | 126 ++++++++++ .../2025.09.mpi/test-send-recv.jl | 210 +++++++++++++++++ roman-temp-debug/Project.toml | 4 + roman-temp-debug/README.md | 1 + roman-temp-debug/bbb.jl | 14 ++ roman-temp-debug/runtests.sh | 22 ++ roman-temp-debug/setup.sh | 7 + 12 files changed, 666 insertions(+) create mode 100644 roman-temp-debug/2025.09.mpi/Project.toml create mode 100755 roman-temp-debug/2025.09.mpi/runtests.sh create mode 100644 roman-temp-debug/2025.09.mpi/sergio.jl create mode 100644 roman-temp-debug/2025.09.mpi/setup.sh create mode 100644 roman-temp-debug/2025.09.mpi/test-isend-irecv.jl create mode 100644 roman-temp-debug/2025.09.mpi/test-isend-irecv_clean.jl create mode 100644 roman-temp-debug/2025.09.mpi/test-send-recv.jl create mode 100644 roman-temp-debug/Project.toml create mode 100644 roman-temp-debug/README.md create mode 100644 roman-temp-debug/bbb.jl create mode 100755 roman-temp-debug/runtests.sh create mode 100644 roman-temp-debug/setup.sh diff --git a/roman-temp-debug/2025.09.mpi/Project.toml b/roman-temp-debug/2025.09.mpi/Project.toml new file mode 100644 index 0000000000..31b989ee28 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/Project.toml @@ -0,0 +1,4 @@ +[deps] +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/roman-temp-debug/2025.09.mpi/runtests.sh b/roman-temp-debug/2025.09.mpi/runtests.sh new file mode 100755 index 0000000000..c71f11f386 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/runtests.sh @@ -0,0 +1,22 @@ +# ------------------- +# perlmutter +# ------------------- +salloc --nodes 1 --qos interactive --time 04:00:00 --constraint gpu --gpus 4 --account=nstaff + +# Flags from https://github.com/PRONTOLab/GB-25/blob/main/sharding/perlmutter_scaling_test.jl +export JULIA_CUDA_MEMORY_POOL=none +export JULIA_CUDA_USE_COMPAT=false + +# Flag from: https://github.com/PRONTOLab/GB-25/blob/main/sharding/common_submission_generator.jl +export XLA_REACTANT_GPU_MEM_FRACTION=0.9 + +srun -n 2 julia --project ./mpi.jl + +# Then added this flag to srun +srun -n 2 --gpus-per-task=1 julia --project ./mpi.jl + + +# ------------------- +# local laptop +# ------------------- +mpiexec -n 2 julia --project mpi.jl diff --git a/roman-temp-debug/2025.09.mpi/sergio.jl 
b/roman-temp-debug/2025.09.mpi/sergio.jl new file mode 100644 index 0000000000..a044f5284f --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/sergio.jl @@ -0,0 +1,29 @@ +using Reactant +using MPI +using Libdl + +Reactant.set_default_backend("cpu") + +tag = 43 +comm = MPI.COMM_WORLD +source = 1 + +println("Here we go!") + +MPI.Init() + +if MPI.Comm_rank(MPI.COMM_WORLD) == 0 + buffer = Reactant.to_rarray(zeros(Int32, 8)) + println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] before - $buffer") + @jit MPI.Recv!(buffer, source, tag, comm) + println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] after - $buffer") + println(isapprox(buffer, ones(8))) +else + buffer = ones(Int32, 8) + destination = 0 + println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] sending - $buffer") + MPI.Send(buffer, destination, tag, comm) + println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] sent!") +end + +MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/setup.sh b/roman-temp-debug/2025.09.mpi/setup.sh new file mode 100644 index 0000000000..7dd7d6d268 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/setup.sh @@ -0,0 +1,7 @@ +# how I set up a julia project in this directory +# These commands create Project.toml and Manifest.toml +julia ] +activate . +dev /global/homes/r/romanlee/Documents/codes/Reactant.jl +add MPI +add Test diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl new file mode 100644 index 0000000000..3d8d363460 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl @@ -0,0 +1,220 @@ +using Test, MPI, Reactant, InteractiveUtils + +Reactant.set_default_backend("cpu") +# Reactant.set_default_backend("gpu") + +MPI.Init() + +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +nranks = MPI.Comm_size(comm) + +# # -------------------------- +# # test MPI.jl Isend / Irecv! 
+# # -------------------------- +# # Skip test if not enough processes +# if nranks < 2 +# @error "Need at least 2 MPI processes for Isend/Irecv test" +# end + +# send_buf = [1, 2, 3, 4, 5] +# tag = 42 +# if rank == 0 +# dest = 1 + +# req_send = MPI.Isend(send_buf, dest, tag, comm) + +# println("Rank 0: Waiting...") + +# MPI.Wait(req_send) + +# println("Rank 0: Sent") + +# elseif rank == 1 +# recv_buf = Vector{Int}(undef, 5) +# source = 0 + +# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) + +# println("Rank 1: Waiting...") + +# status = MPI.Wait(req_recv) + +# println( "Rank 1: Received: $(recv_buf == send_buf)" ) +# # @test MPI.Get_source(status) == 0 +# # @test MPI.Get_tag(status) == 42 + +# end +# # -------------------------- + + +# -------------------------- +# # test Reactant Isend +# -------------------------- +# if nranks < 2 +# @error "Need at least 2 MPI processes for Isend/Irecv test" +# end +# +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# if rank == 0 +# dest = 1 + +# req_send = @jit MPI.Isend(send_buf, dest, tag, comm) + +# MPI.Wait(req_send) + +# elseif rank == 1 +# recv_buf = Vector{Int}(undef, 5) +# source = 0 + +# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) + +# status = MPI.Wait(req_recv) + +# println( recv_buf == send_buf ) +# # @test MPI.Get_source(status) == 0 +# # @test MPI.Get_tag(status) == 42 + +# end + + +# -------------------------- +# debug +# -------------------------- +# # runs without crashing +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# dest = 1 +# function Isend_Wait(send_buf, dest, tag, comm) +# req = MPI.Isend(send_buf, dest, tag, comm) +# MPI.Wait(req) +# return nothing +# end +# # @jit Isend_Wait(send_buf, dest, tag, comm) +# println(@code_hlo optimize=false Isend_Wait(send_buf, dest, tag, comm)) + + +# # runs without crashing +# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# src = 1 +# # function Irecv_Wait(recv_buf, src, tag, comm) +# function Irecv_Wait(recv_buf, src, tag, comm) +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# return nothing +# end +# # @jit Irecv_Wait(recv_buf, src, tag, comm) +# println(@code_hlo optimize=false Irecv_Wait(recv_buf, src, tag, comm)) + + +# # recv_buf not modified +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# function Isend_Irecv!(comm, rank, send_buf, recv_buf) +# if rank==0 +# dest = 1 +# tag = 42 +# req = MPI.Isend(send_buf, dest, tag, comm) +# MPI.Wait(req) +# elseif rank==1 +# src = 0 +# tag = 42 +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# end + +# return recv_buf +# end +# # recv_buf = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) +# +# rank==1 && sleep(3) +# println("\nRank: $rank") +# println(@code_hlo optimize=false Isend_Irecv!(comm, rank, send_buf, recv_buf)) + + +# # hangs +# # send_buf = ConcreteRArray(fill(1)) +# # recv_buf = ConcreteRArray(fill(12)) +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# function aaa(comm, rank, send_buf, recv_buf, tag) +# if rank == 0 +# # dest = 1 +# dest = 333 +# MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# # src = 0 +# src = 555 +# MPI.Recv!(recv_buf, src, tag, comm) +# # println( recv_buf == send_buf ) +# end +# return nothing +# end +# # @jit aaa(comm, rank, send_buf, recv_buf, tag) +# rank==1 && sleep(5) +# println("\nRank: $rank") +# # println(@code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) + +# # bbb = @compile aaa(comm, 
rank, send_buf, recv_buf, tag) +# # if rank==0 +# # println("\nlowered") +# # println(@code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) +# # println("\ntyped") +# # println(@code_typed bbb(comm, rank, send_buf, recv_buf, tag)) +# # println("\nllvm") +# # println(@code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) +# # end + + +# # works +# # send_buf = ConcreteRArray(fill(1)) +# # recv_buf = ConcreteRArray(fill(12)) +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# # dest = 1 +# dest = 333 +# # @jit MPI.Send(send_buf, dest, tag, comm) +# println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) +# elseif rank == 1 +# # src = 0 +# src = 555 +# # @jit MPI.Recv!(recv_buf, src, tag, comm) +# println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) +# println( recv_buf == send_buf ) +# end + + + +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# dest = 333 +# bbb = @compile MPI.Send(send_buf, dest, tag, comm) + +# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) + +# # println(@code_lowered bbb(send_buf, dest, tag, comm)) + +# println("\nlowered") +# println(@code_lowered bbb(send_buf, dest, tag, comm)) +# println("\ntyped") +# println(@code_typed bbb(send_buf, dest, tag, comm)) +# println("\nllvm") +# println(@code_llvm bbb(send_buf, dest, tag, comm)) + +# # elseif rank == 1 +# # # # src = 0 +# # # src = 555 +# # # # @jit MPI.Recv!(recv_buf, src, tag, comm) +# # # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) +# # # println( recv_buf == send_buf ) +# end + + +MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv_clean.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv_clean.jl new file mode 100644 index 0000000000..d1b9c68d85 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv_clean.jl @@ -0,0 +1,126 @@ +using Test, MPI, Reactant + +Reactant.set_default_backend("cpu") +# Reactant.set_default_backend("gpu") + +MPI.Init() + +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +nranks = MPI.Comm_size(comm) + +println("BEFORE: rank $rank") + +# # runs without crashing +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# dest = 1 +# function Isend_Wait(send_buf, dest, tag, comm) +# req = MPI.Isend(send_buf, dest, tag, comm) +# MPI.Wait(req) +# return nothing +# end +# @jit Isend_Wait(send_buf, dest, tag, comm) + + + +# # runs without crashing +# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# src = 1 +# function Irecv!_Wait(recv_buf, src, tag, comm) +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# return nothing +# end +# @jit Irecv!_Wait(recv_buf, src, tag, comm) + + +# # recv_buf not modified +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# function Isend_Irecv!(comm, rank, send_buf, recv_buf) +# if rank==0 +# println("rank 0") + +# dest = 1 +# tag = 42 + +# req = MPI.Isend(send_buf, dest, tag, comm) +# MPI.Wait(req) +# elseif rank==1 +# println("rank 1") + +# src = 0 +# tag = 42 + +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# end + +# return recv_buf +# end +# recv_buf = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) +# println(recv_buf) + + +# # hangs +# send_buf = ConcreteRArray(fill(1)) +# recv_buf = ConcreteRArray(fill(12)) +# tag = 43 +# function aaa(comm, rank, send_buf, recv_buf, tag) +# if rank == 0 +# dest = 1 +# 
MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# src = 0 +# MPI.Recv!(recv_buf, src, tag, comm) +# # println( recv_buf == send_buf ) +# end +# return nothing +# end +# # @jit aaa(comm, rank, send_buf, recv_buf, tag) +# # display(@code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) +# display(@code_xla aaa(comm, rank, send_buf, recv_buf, tag)) + + +# # works +# send_buf = ConcreteRArray(fill(1)) +# recv_buf = ConcreteRArray(fill(12)) +# tag = 43 +# if rank == 0 +# dest = 1 +# @jit MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# src = 0 +# @jit MPI.Recv!(recv_buf, src, tag, comm) +# println( recv_buf == send_buf ) +# end + + + +# # hangs debug +# send_buf = ConcreteRArray(fill(1)) +# recv_buf = ConcreteRArray(fill(12)) +# tag = 43 +# function aaa(comm, rank, send_buf, recv_buf, tag) +# # if rank == 0 +# dest = 1 +# MPI.Send(send_buf, dest, tag, comm) +# # elseif rank == 1 +# src = 0 +# MPI.Recv!(recv_buf, src, tag, comm) +# # println( recv_buf == send_buf ) +# # end +# return nothing +# end +# # @jit aaa(comm, rank, send_buf, recv_buf, tag) +# # display(@code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) +# display(@code_xla aaa(comm, rank, send_buf, recv_buf, tag)) + + + + +println("AFTER: rank $rank") + +MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-send-recv.jl b/roman-temp-debug/2025.09.mpi/test-send-recv.jl new file mode 100644 index 0000000000..bdbec587b2 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/test-send-recv.jl @@ -0,0 +1,210 @@ +using Test, MPI, Reactant + +Reactant.set_default_backend("cpu") +# Reactant.set_default_backend("gpu") + +MPI.Init() + +# println(@code_hlo optimize=false MPI.Comm_rank(MPI.COMM_WORLD)) +# println(@code_hlo optimize=true MPI.Comm_rank(MPI.COMM_WORLD)) + +# pass on cpu +# fail on gpu: segfault when trying to return res in Ops.jl comm_rank +@testset "Comm_rank" begin + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + @test rank == @jit MPI.Comm_rank(comm) +end + +# pass on cpu +# fail on gpu: segfaulta upon trying to return res in Ops.jl comm_size +@testset "Comm_size" begin + comm = MPI.COMM_WORLD + nranks = MPI.Comm_size(comm) + @test nranks == @jit MPI.Comm_size(comm) +end + +@testset "Allreduce" begin + comm = MPI.COMM_WORLD + nranks = MPI.Comm_size(comm) + + # test good-ol MPI.jl allreduce + @test nranks == MPI.Allreduce(1, MPI.SUM, MPI.COMM_WORLD) + + # pass on cpu + # pass on gpu! + # test Reactant allreduce + @test nranks == @jit MPI.Allreduce(1, MPI.SUM, MPI.COMM_WORLD) +end + +@testset "Send, Recv!" begin + + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + nranks = MPI.Comm_size(comm) + + # "Need at least 2 MPI processes for send tests" + if nranks < 2 + @warn "need more than 2 mpi ranks, skipping" + return + end + + # test MPI.jl Send/Recv + @testset "MPI.jl Send / Recv!" begin + send_buf = fill(1) + tag = 43 + if rank == 0 + MPI.Send(send_buf, comm; dest=1, tag=tag) + @test true # Send completed + elseif rank == 1 + recv_buf = fill(12) + MPI.Recv!(recv_buf, comm; source=0, tag=tag) + @test recv_buf == send_buf + end + end + + # test MPI.jl Send / Reactant Recv + @testset "MPI.jl Send / Reactant Recv!" begin + send_buf = fill(1) + tag = 43 + if rank == 0 + MPI.Send(send_buf, comm; dest=1, tag=tag) + @test true + elseif rank == 1 + recv_buf = ConcreteRArray(fill(12)) + source = 0 + @jit MPI.Recv!(recv_buf, source, tag, comm) + @test recv_buf == send_buf + end + end + + # test Reactant Send / MPI.jl Recv + @testset "Reactant Send / MPI.jl Recv!" 
begin + send_buf = ConcreteRArray(fill(1)) + tag = 43 + if rank == 0 + dest = 1 + @jit MPI.Send(send_buf, dest, tag, comm) + @test true + elseif rank == 1 + recv_buf = fill(12) + MPI.Recv!(recv_buf, comm; source=0, tag=tag) + @test recv_buf == send_buf + end + end + + # test Reactant Send/Recv + @testset "Reactant Send / Recv!" begin + send_buf = ConcreteRArray(fill(1)) + tag = 43 + if rank == 0 + # Send: pass on cpu, pass on gpu + dest = 1 + @jit MPI.Send(send_buf, dest, tag, comm) + @test true # Send completed + elseif rank == 1 + # hang on cpu + # segfault on gpu upon trying to reference res + recv_buf = ConcreteRArray(fill(12)) + src = 0 + @jit MPI.Recv!(recv_buf, src, tag, comm) + @test recv_buf == send_buf + end + end +end + +# ---------- +# debug +# ---------- +# comm = MPI.COMM_WORLD +# rank = MPI.Comm_rank(comm) +# nranks = MPI.Comm_size(comm) + +# send_buf = ConcreteRArray(fill(1)) +# tag = 43 +# dest = 1 +# # @jit dbSend(send_buf, dest, tag, comm) +# # @jit MPI.Senddd(send_buf, dest, tag, comm) +# # @jit Senddd(send_buf, dest, tag, comm) +# @jit func_foo() + +# if nranks < 2 +# @error "Need at least 2 MPI processes for send tests. Skipping." +# end + +# # test Reactant Send/Recv +# send_buf = ConcreteRArray(fill(1)) +# tag = 43 +# if rank == 0 +# # Send: pass on cpu, pass on gpu +# # dest = 1 +# dest = 1 +# # @jit MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# # # hang on cpu +# # # segfault on gpu upon trying to reference res +# # recv_buf = ConcreteRArray(fill(12)) +# # src = 0 +# # @jit MPI.Recv!(recv_buf, src, tag, comm) +# end + + + + +# # # test Reactant Send/Recv +# # send_buf = ConcreteRArray(fill(1)) +# # if rank == 0 +# # # Send: pass on cpu, pass on gpu +# # @jit MPI.Send(send_buf, 1, 0, comm) + +# # dest = 12 +# # tag = 33 +# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) + +# # # @test true # Send completed +# # elseif rank == 1 +# # # # hang on cpu +# # # # segfault on gpu upon trying to reference res +# # # recv_buf = ConcreteRArray(fill(12)) +# # # # @jit MPI.Recv!(recv_buf, 0, 0, comm) +# # # source = 12 +# # # tag = 35 +# # # println(@code_hlo optimize=false MPI.Recv!(recv_buf, source, tag, comm)) +# # # # @test recv_buf == send_buf + +# # # # println(@code_hlo MPI.Recv!(recv_buf, 0, 0, comm)) +# # end + +# send_buf = ConcreteRArray(fill(1)) +# tag = 43 +# if rank == 0 +# dest = 3333 + +# println("@code_hlo optimize=false:") +# println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) +# println("") + +# # println("@code_hlo optimize=:before_jit:") +# # println(@code_hlo optimize=:before_jit MPI.Send(send_buf, dest, tag, comm)) +# # println("") + +# # println("@jit MPI.Send:") +# # @jit MPI.Send(send_buf, dest, tag, comm) + +# elseif rank == 1 +# # recv_buf = ConcreteRArray(fill(12)) +# # source = 0 + +# # println("code hlo:") +# # println(@code_hlo optimize=false MPI.Recv!(recv_buf, source, tag, comm)) +# # println("") + +# # println("@jit MPI.Recv!:") +# # @jit MPI.Recv!(recv_buf, source, tag, comm) + +# # # # println("after ", recv_buf==send_buf) +# end + + + +MPI.Finalize() diff --git a/roman-temp-debug/Project.toml b/roman-temp-debug/Project.toml new file mode 100644 index 0000000000..31b989ee28 --- /dev/null +++ b/roman-temp-debug/Project.toml @@ -0,0 +1,4 @@ +[deps] +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/roman-temp-debug/README.md b/roman-temp-debug/README.md new file mode 100644 index 
0000000000..04c00d9276 --- /dev/null +++ b/roman-temp-debug/README.md @@ -0,0 +1 @@ +I will delete this before we submit the PR diff --git a/roman-temp-debug/bbb.jl b/roman-temp-debug/bbb.jl new file mode 100644 index 0000000000..2e640a283b --- /dev/null +++ b/roman-temp-debug/bbb.jl @@ -0,0 +1,14 @@ +g(x::Float64, y) = 2x + y +display(g) + +# g(x, y::Float64) = x + 2y +# display(g) + +# println(g(2.0, 3)) + +# println(g(2, 3.0)) + +# println(g(2.0, 3.0)) + +g(x::Number, y) = 2x + y +println(g(2.0, 3)) diff --git a/roman-temp-debug/runtests.sh b/roman-temp-debug/runtests.sh new file mode 100755 index 0000000000..c71f11f386 --- /dev/null +++ b/roman-temp-debug/runtests.sh @@ -0,0 +1,22 @@ +# ------------------- +# perlmutter +# ------------------- +salloc --nodes 1 --qos interactive --time 04:00:00 --constraint gpu --gpus 4 --account=nstaff + +# Flags from https://github.com/PRONTOLab/GB-25/blob/main/sharding/perlmutter_scaling_test.jl +export JULIA_CUDA_MEMORY_POOL=none +export JULIA_CUDA_USE_COMPAT=false + +# Flag from: https://github.com/PRONTOLab/GB-25/blob/main/sharding/common_submission_generator.jl +export XLA_REACTANT_GPU_MEM_FRACTION=0.9 + +srun -n 2 julia --project ./mpi.jl + +# Then added this flag to srun +srun -n 2 --gpus-per-task=1 julia --project ./mpi.jl + + +# ------------------- +# local laptop +# ------------------- +mpiexec -n 2 julia --project mpi.jl diff --git a/roman-temp-debug/setup.sh b/roman-temp-debug/setup.sh new file mode 100644 index 0000000000..7dd7d6d268 --- /dev/null +++ b/roman-temp-debug/setup.sh @@ -0,0 +1,7 @@ +# how I set up a julia project in this directory +# These commands create Project.toml and Manifest.toml +julia ] +activate . +dev /global/homes/r/romanlee/Documents/codes/Reactant.jl +add MPI +add Test From 72e73fb533e06382725ce7d4b5da07acef632711 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 4 Sep 2025 15:22:03 -0700 Subject: [PATCH 63/97] Update test-isend-irecv.jl --- .../2025.09.mpi/test-isend-irecv.jl | 84 ++++++++++++------- 1 file changed, 54 insertions(+), 30 deletions(-) diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl index 3d8d363460..bdc2cb5e2b 100644 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl @@ -91,7 +91,7 @@ nranks = MPI.Comm_size(comm) # MPI.Wait(req) # return nothing # end -# # @jit Isend_Wait(send_buf, dest, tag, comm) +# @jit Isend_Wait(send_buf, dest, tag, comm) # println(@code_hlo optimize=false Isend_Wait(send_buf, dest, tag, comm)) @@ -135,56 +135,80 @@ nranks = MPI.Comm_size(comm) # # hangs -# # send_buf = ConcreteRArray(fill(1)) -# # recv_buf = ConcreteRArray(fill(12)) # send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) # tag = 43 # function aaa(comm, rank, send_buf, recv_buf, tag) # if rank == 0 -# # dest = 1 -# dest = 333 +# dest = 1 # MPI.Send(send_buf, dest, tag, comm) # elseif rank == 1 -# # src = 0 -# src = 555 +# src = 0 # MPI.Recv!(recv_buf, src, tag, comm) # # println( recv_buf == send_buf ) # end # return nothing # end -# # @jit aaa(comm, rank, send_buf, recv_buf, tag) -# rank==1 && sleep(5) +# @jit aaa(comm, rank, send_buf, recv_buf, tag) # println("\nRank: $rank") -# # println(@code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# # bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) -# # if rank==0 -# # println("\nlowered") -# # println(@code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) -# # 
println("\ntyped") -# # println(@code_typed bbb(comm, rank, send_buf, recv_buf, tag)) -# # println("\nllvm") -# # println(@code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) -# # end + +send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +tag = 43 +function aaa(comm, rank, send_buf, recv_buf, tag) + if rank == 0 + dest = 333 + MPI.Send(send_buf, dest, tag, comm) + elseif rank == 1 + src = 555 + MPI.Recv!(recv_buf, src, tag, comm) + # println( recv_buf == send_buf ) + end + return nothing +end +# @jit aaa(comm, rank, send_buf, recv_buf, tag) +rank==1 && sleep(5) +# println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) + +# # # bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) +# # # if rank==0 +# # # println("\nlowered") +# # # println(@code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) +# # # println("\ntyped") +# # # println(@code_typed bbb(comm, rank, send_buf, recv_buf, tag)) +# # # println("\nllvm") +# # # println(@code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) +# # # end # # works -# # send_buf = ConcreteRArray(fill(1)) -# # recv_buf = ConcreteRArray(fill(12)) # send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) # tag = 43 # if rank == 0 -# # dest = 1 -# dest = 333 -# # @jit MPI.Send(send_buf, dest, tag, comm) -# println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) +# dest = 1 +# @jit MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# src = 0 +# @jit MPI.Recv!(recv_buf, src, tag, comm) +# println( recv_buf == send_buf ) +# end + +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# dest = 1 +# # dest = 333 +# @jit MPI.Send(send_buf, dest, tag, comm) +# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) # elseif rank == 1 -# # src = 0 -# src = 555 -# # @jit MPI.Recv!(recv_buf, src, tag, comm) -# println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) +# src = 0 +# # src = 555 +# @jit MPI.Recv!(recv_buf, src, tag, comm) +# # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) # println( recv_buf == send_buf ) # end @@ -217,4 +241,4 @@ nranks = MPI.Comm_size(comm) # end -MPI.Finalize() +# MPI.Finalize() From f200a6fcdb1ded68de662ffa95890aec3d59ce82 Mon Sep 17 00:00:00 2001 From: romanlee Date: Mon, 8 Sep 2025 11:19:15 -0700 Subject: [PATCH 64/97] Clean up formatting --- ext/ReactantMPIExt/Ops.jl | 2 ++ ext/ReactantMPIExt/Overrides.jl | 43 ++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 96750c4ff5..b3e3c108fc 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -581,6 +581,8 @@ function wait( sym_name = "enzymexla_wrapper_MPI_Wait" sym_attr = IR.FlatSymbolRefAttribute(sym_name) + # # TODO Temporarily commented out bc otherwise can't compile together with any other + # # func that tries to inject the same thing (e.g., isend, irecv) # IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index fe3b26323c..535ab617cd 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -105,34 
+105,43 @@ end # return Ops.isend(buf, tag, dest) # end -function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm) - tag = Reactant.Ops.constant(tag) - source = Reactant.Ops.constant(source) - return MPI.Recv!(buf, source, tag, comm) -end - function MPI.Recv!( - recvbuf::TracedRArray, + buf::TracedRArray, source::Integer, tag::Integer, - comm::MPI.Comm, - ::Type{MPI.API.MPI_Status}, + comm::MPI.Comm ) - return MPI.Recv!(recvbuf, source, tag, comm) + tag = Reactant.Ops.constant(tag) + source = Reactant.Ops.constant(source) + return MPI.Recv!(buf, source, tag, comm) end -function MPI.Recv!( - recvbuf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm, ::Nothing -) - return MPI.Recv!(recvbuf, source, tag, comm) -end +# TODO Do we need these? Comment out at least until everything is working +# function MPI.Recv!( +# buf::TracedRArray, +# source::Integer, +# tag::Integer, +# comm::MPI.Comm, +# ::Type{MPI.API.MPI_Status}, +# ) +# return MPI.Recv!(buf, source, tag, comm) +# end + +# function MPI.Recv!( +# buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm, ::Nothing +# ) +# return MPI.Recv!(buf, source, tag, comm) +# end # TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` function MPI.Recv!( - recvbuf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm + buf::TracedRArray, + source::TracedRNumber, + tag::TracedRNumber, + comm::MPI.Comm ) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - return Ops.recv!(recvbuf, tag, source) + return Ops.recv!(buf, tag, source) end function MPI.Irecv!( From 8f12edd3324aa7474907a3c98468c49f28428661 Mon Sep 17 00:00:00 2001 From: romanlee Date: Mon, 8 Sep 2025 17:27:34 -0700 Subject: [PATCH 65/97] Send/Recv altogether works as long as message size is large enough If the message size is small (<64KB) Send buffers the message and doesn't block. This seems to screw things up somehow --- .../2025.09.mpi/test-isend-irecv.jl | 174 ++---------- ...nd-irecv_clean.jl => test-isend-irecv0.jl} | 0 .../2025.09.mpi/test-isend-irecv1.jl | 252 ++++++++++++++++++ 3 files changed, 280 insertions(+), 146 deletions(-) rename roman-temp-debug/2025.09.mpi/{test-isend-irecv_clean.jl => test-isend-irecv0.jl} (100%) create mode 100644 roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl index bdc2cb5e2b..0c776b9362 100644 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl @@ -9,168 +9,50 @@ comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) -# # -------------------------- -# # test MPI.jl Isend / Irecv! 
-# # -------------------------- -# # Skip test if not enough processes -# if nranks < 2 -# @error "Need at least 2 MPI processes for Isend/Irecv test" -# end - -# send_buf = [1, 2, 3, 4, 5] -# tag = 42 -# if rank == 0 -# dest = 1 - -# req_send = MPI.Isend(send_buf, dest, tag, comm) - -# println("Rank 0: Waiting...") - -# MPI.Wait(req_send) - -# println("Rank 0: Sent") - -# elseif rank == 1 -# recv_buf = Vector{Int}(undef, 5) -# source = 0 - -# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) - -# println("Rank 1: Waiting...") - -# status = MPI.Wait(req_recv) - -# println( "Rank 1: Received: $(recv_buf == send_buf)" ) -# # @test MPI.Get_source(status) == 0 -# # @test MPI.Get_tag(status) == 42 - -# end -# # -------------------------- - - -# -------------------------- -# # test Reactant Isend -# -------------------------- -# if nranks < 2 -# @error "Need at least 2 MPI processes for Isend/Irecv test" -# end -# -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# if rank == 0 -# dest = 1 - -# req_send = @jit MPI.Isend(send_buf, dest, tag, comm) - -# MPI.Wait(req_send) - -# elseif rank == 1 -# recv_buf = Vector{Int}(undef, 5) -# source = 0 - -# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) - -# status = MPI.Wait(req_recv) - -# println( recv_buf == send_buf ) -# # @test MPI.Get_source(status) == 0 -# # @test MPI.Get_tag(status) == 42 - -# end - - -# -------------------------- -# debug -# -------------------------- -# # runs without crashing -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# dest = 1 -# function Isend_Wait(send_buf, dest, tag, comm) -# req = MPI.Isend(send_buf, dest, tag, comm) -# MPI.Wait(req) -# return nothing -# end -# @jit Isend_Wait(send_buf, dest, tag, comm) -# println(@code_hlo optimize=false Isend_Wait(send_buf, dest, tag, comm)) - - -# # runs without crashing -# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# src = 1 -# # function Irecv_Wait(recv_buf, src, tag, comm) -# function Irecv_Wait(recv_buf, src, tag, comm) -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# return nothing -# end -# # @jit Irecv_Wait(recv_buf, src, tag, comm) -# println(@code_hlo optimize=false Irecv_Wait(recv_buf, src, tag, comm)) +send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # hangs with small bufs, Send no block, Recv no receive send, no bueno +recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# send_buf = zeros(UInt8, 65536) # works with bufs > 65KB +# recv_buf = ones(UInt8, 65536) +tag = 43 +function aaa(comm, rank, send_buf, recv_buf, tag) + if rank == 0 + dest = 1 + # ccall(:jl_breakpoint, Cvoid, (Any,), dest) + MPI.Send(send_buf, dest, tag, comm) + elseif rank == 1 + src = 0 + MPI.Recv!(recv_buf, src, tag, comm) + println( recv_buf == send_buf ) + end + return nothing +end -# # recv_buf not modified -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# function Isend_Irecv!(comm, rank, send_buf, recv_buf) -# if rank==0 -# dest = 1 -# tag = 42 -# req = MPI.Isend(send_buf, dest, tag, comm) -# MPI.Wait(req) -# elseif rank==1 -# src = 0 -# tag = 42 -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# end +@jit aaa(comm, rank, send_buf, recv_buf, tag) +println("Rank $rank") -# return recv_buf -# end -# # recv_buf = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) -# -# rank==1 && sleep(3) -# println("\nRank: $rank") -# println(@code_hlo optimize=false Isend_Irecv!(comm, rank, send_buf, recv_buf)) +# rank==1 && sleep(5) +# println("\nRank $rank:\n", @code_hlo aaa(comm, rank, send_buf, 
recv_buf, tag)) -# # hangs # send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) # tag = 43 # function aaa(comm, rank, send_buf, recv_buf, tag) # if rank == 0 -# dest = 1 +# dest = 333 # MPI.Send(send_buf, dest, tag, comm) # elseif rank == 1 -# src = 0 +# src = 555 # MPI.Recv!(recv_buf, src, tag, comm) # # println( recv_buf == send_buf ) # end # return nothing # end -# @jit aaa(comm, rank, send_buf, recv_buf, tag) -# println("\nRank: $rank") - - -send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -tag = 43 -function aaa(comm, rank, send_buf, recv_buf, tag) - if rank == 0 - dest = 333 - MPI.Send(send_buf, dest, tag, comm) - elseif rank == 1 - src = 555 - MPI.Recv!(recv_buf, src, tag, comm) - # println( recv_buf == send_buf ) - end - return nothing -end -# @jit aaa(comm, rank, send_buf, recv_buf, tag) -rank==1 && sleep(5) -# println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) +# # @jit aaa(comm, rank, send_buf, recv_buf, tag) +# rank==1 && sleep(5) +# # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) # # # bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) # # # if rank==0 @@ -241,4 +123,4 @@ println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) # end -# MPI.Finalize() +MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv_clean.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl similarity index 100% rename from roman-temp-debug/2025.09.mpi/test-isend-irecv_clean.jl rename to roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl new file mode 100644 index 0000000000..12b48afc6a --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl @@ -0,0 +1,252 @@ +using Test, MPI, Reactant, InteractiveUtils + +Reactant.set_default_backend("cpu") +# Reactant.set_default_backend("gpu") + +MPI.Init() + +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +nranks = MPI.Comm_size(comm) + +# # -------------------------- +# # test MPI.jl Isend / Irecv! 
+# # -------------------------- +# # Skip test if not enough processes +# if nranks < 2 +# @error "Need at least 2 MPI processes for Isend/Irecv test" +# end + +# send_buf = [1, 2, 3, 4, 5] +# tag = 42 +# if rank == 0 +# dest = 1 + +# req_send = MPI.Isend(send_buf, dest, tag, comm) + +# println("Rank 0: Waiting...") + +# MPI.Wait(req_send) + +# println("Rank 0: Sent") + +# elseif rank == 1 +# recv_buf = Vector{Int}(undef, 5) +# source = 0 + +# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) + +# println("Rank 1: Waiting...") + +# status = MPI.Wait(req_recv) + +# println( "Rank 1: Received: $(recv_buf == send_buf)" ) +# # @test MPI.Get_source(status) == 0 +# # @test MPI.Get_tag(status) == 42 + +# end +# # -------------------------- + + +# -------------------------- +# # test Reactant Isend +# -------------------------- +# if nranks < 2 +# @error "Need at least 2 MPI processes for Isend/Irecv test" +# end +# +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# if rank == 0 +# dest = 1 + +# req_send = @jit MPI.Isend(send_buf, dest, tag, comm) + +# MPI.Wait(req_send) + +# elseif rank == 1 +# recv_buf = Vector{Int}(undef, 5) +# source = 0 + +# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) + +# status = MPI.Wait(req_recv) + +# println( recv_buf == send_buf ) +# # @test MPI.Get_source(status) == 0 +# # @test MPI.Get_tag(status) == 42 + +# end + + +# -------------------------- +# debug +# -------------------------- +# # runs without crashing +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# dest = 1 +# function Isend_Wait(send_buf, dest, tag, comm) +# req = MPI.Isend(send_buf, dest, tag, comm) +# MPI.Wait(req) +# return nothing +# end +# @jit Isend_Wait(send_buf, dest, tag, comm) +# println(@code_hlo optimize=false Isend_Wait(send_buf, dest, tag, comm)) + + +# # runs without crashing +# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# src = 1 +# # function Irecv_Wait(recv_buf, src, tag, comm) +# function Irecv_Wait(recv_buf, src, tag, comm) +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# return nothing +# end +# # @jit Irecv_Wait(recv_buf, src, tag, comm) +# println(@code_hlo optimize=false Irecv_Wait(recv_buf, src, tag, comm)) + + +# # recv_buf not modified +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# function Isend_Irecv!(comm, rank, send_buf, recv_buf) +# if rank==0 +# dest = 1 +# tag = 42 +# req = MPI.Isend(send_buf, dest, tag, comm) +# MPI.Wait(req) +# elseif rank==1 +# src = 0 +# tag = 42 +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# end + +# return recv_buf +# end +# # recv_buf = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) +# +# rank==1 && sleep(3) +# println("\nRank: $rank") +# println(@code_hlo optimize=false Isend_Irecv!(comm, rank, send_buf, recv_buf)) + + +# # hangs +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# function aaa(comm, rank, send_buf, recv_buf, tag) +# if rank == 0 +# dest = 1 +# # ccall(:jl_breakpoint, Cvoid, (Any,), dest) +# MPI.Send(send_buf, dest, tag, comm) +# # while true +# # sleep(5) +# # end +# elseif rank == 1 +# src = 0 +# # ccall(:jl_breakpoint, Cvoid, (Any,), src) +# MPI.Recv!(recv_buf, src, tag, comm) +# # println( recv_buf == send_buf ) +# # while true +# # sleep(5) +# # end +# end +# return nothing +# end +# @jit aaa(comm, rank, send_buf, recv_buf, tag) +# # println("\nRank: $rank") + + +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = 
ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# function aaa(comm, rank, send_buf, recv_buf, tag) +# if rank == 0 +# dest = 333 +# MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# src = 555 +# MPI.Recv!(recv_buf, src, tag, comm) +# # println( recv_buf == send_buf ) +# end +# return nothing +# end +# # @jit aaa(comm, rank, send_buf, recv_buf, tag) +# rank==1 && sleep(5) +# # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) + +# # # bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) +# # # if rank==0 +# # # println("\nlowered") +# # # println(@code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) +# # # println("\ntyped") +# # # println(@code_typed bbb(comm, rank, send_buf, recv_buf, tag)) +# # # println("\nllvm") +# # # println(@code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) +# # # end + + +# # works +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# dest = 1 +# @jit MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# src = 0 +# @jit MPI.Recv!(recv_buf, src, tag, comm) +# println( recv_buf == send_buf ) +# end + +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# dest = 1 +# # dest = 333 +# @jit MPI.Send(send_buf, dest, tag, comm) +# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) +# elseif rank == 1 +# src = 0 +# # src = 555 +# @jit MPI.Recv!(recv_buf, src, tag, comm) +# # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) +# println( recv_buf == send_buf ) +# end + + + +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# dest = 333 +# bbb = @compile MPI.Send(send_buf, dest, tag, comm) + +# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) + +# # println(@code_lowered bbb(send_buf, dest, tag, comm)) + +# println("\nlowered") +# println(@code_lowered bbb(send_buf, dest, tag, comm)) +# println("\ntyped") +# println(@code_typed bbb(send_buf, dest, tag, comm)) +# println("\nllvm") +# println(@code_llvm bbb(send_buf, dest, tag, comm)) + +# # elseif rank == 1 +# # # # src = 0 +# # # src = 555 +# # # # @jit MPI.Recv!(recv_buf, src, tag, comm) +# # # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) +# # # println( recv_buf == send_buf ) +# end + + +MPI.Finalize() From ff8c4f5155e02f924d524d7d24ef683b49dce6dd Mon Sep 17 00:00:00 2001 From: romanlee Date: Tue, 9 Sep 2025 16:12:19 -0700 Subject: [PATCH 66/97] Add TODO --- ext/ReactantMPIExt/Ops.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index b3e3c108fc..4ba1fee0cf 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -493,6 +493,7 @@ function recv!( errcode.mlir_data = IR.result(ret, 1) recvbuf.mlir_data = IR.result(ret, 2) + # TODO is returning recvbuf the best choice here stylistically? return errcode, recvbuf end From 8c975bf7139b0dba30889ef590fb81202dc106d2 Mon Sep 17 00:00:00 2001 From: romanlee Date: Tue, 9 Sep 2025 16:39:30 -0700 Subject: [PATCH 67/97] Add test where Send and Recv! 
are compiled together in one function --- test/integration/mpi.jl | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index e467151146..79aed4b04c 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -1,5 +1,8 @@ using Test, MPI, Reactant +# # MPI only works on cpu currently --- is this the right way/place to enforce that? +# Reactant.set_default_backend("cpu") + MPI.Init() @testset "Comm_rank" begin @@ -28,12 +31,12 @@ end # test MPI.jl Send / Reactant Recv @testset "MPI.jl Send / Reactant Recv!" begin - send_buf = fill(1) + send_buf = ones(5) tag = 43 if rank == 0 MPI.Send(send_buf, comm; dest=1, tag=tag) elseif rank == 1 - recv_buf = ConcreteRArray(fill(12)) + recv_buf = ConcreteRArray(zeros(5)) source = 0 @jit MPI.Recv!(recv_buf, source, tag, comm) @test recv_buf == send_buf @@ -42,35 +45,52 @@ end # test Reactant Send / MPI.jl Recv @testset "Reactant Send / MPI.jl Recv!" begin - send_buf = ConcreteRArray(fill(1)) + send_buf = ConcreteRArray(ones(5)) tag = 43 if rank == 0 dest = 1 @jit MPI.Send(send_buf, dest, tag, comm) elseif rank == 1 - recv_buf = fill(12) + recv_buf = zeros(5) MPI.Recv!(recv_buf, comm; source=0, tag=tag) @test recv_buf == send_buf end end # test Reactant Send/Recv - @testset "Reactant Send / Recv!" begin - send_buf = ConcreteRArray(fill(1)) + @testset "Reactant Send / Recv! - compiled separately" begin + send_buf = ConcreteRArray(ones(5)) tag = 43 if rank == 0 - # Send: pass on cpu, pass on gpu dest = 1 @jit MPI.Send(send_buf, dest, tag, comm) elseif rank == 1 - # hang on cpu - # segfault on gpu upon trying to reference res - recv_buf = ConcreteRArray(fill(12)) + recv_buf = ConcreteRArray(zeros(5)) src = 0 @jit MPI.Recv!(recv_buf, src, tag, comm) @test recv_buf == send_buf end end + + @testset "Reactant Send / Recv! - compiled together" begin + send_buf = ConcreteRArray(ones(5)) + recv_buf = ConcreteRArray(zeros(5)) + tag = 43 + function sendrecv!(comm, rank, send_buf, recv_buf, tag) + if rank == 0 + dest = 1 + err_code = MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but unfort have to return something otherwise julia optimizes this out @code_lowered + return err_code + elseif rank == 1 + src = 0 + MPI.Recv!(recv_buf, src, tag, comm) + return recv_buf + end + end + @jit sendrecv!(comm, rank, send_buf, recv_buf, tag) + rank==1 && @test recv_buf == send_buf + end + end MPI.Finalize() From 62657ad23a3f4e17dd7e0c9e3e091cc37fde3d51 Mon Sep 17 00:00:00 2001 From: romanlee Date: Tue, 9 Sep 2025 17:52:18 -0700 Subject: [PATCH 68/97] Update tests, send/recv-compiled-together test now working --- .../2025.09.mpi/test-isend-irecv.jl | 120 +++++++----------- .../2025.09.mpi/test-isend-irecv0.jl | 113 ++++------------- 2 files changed, 75 insertions(+), 158 deletions(-) diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl index 0c776b9362..edf9da9df6 100644 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl @@ -10,116 +10,90 @@ rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) -send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # hangs with small bufs, Send no block, Recv no receive send, no bueno +# ---------------- +# Send/Recv! 
in one func +# ---------------- +send_buf = ConcreteRArray([1, 2, 3, 4, 5]) recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# send_buf = zeros(UInt8, 65536) # works with bufs > 65KB -# recv_buf = ones(UInt8, 65536) tag = 43 function aaa(comm, rank, send_buf, recv_buf, tag) if rank == 0 dest = 1 # ccall(:jl_breakpoint, Cvoid, (Any,), dest) - MPI.Send(send_buf, dest, tag, comm) + return MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but have to return this otherwise julia optimizes this out elseif rank == 1 src = 0 + # return MPI.Recv!(recv_buf, src, tag, comm) MPI.Recv!(recv_buf, src, tag, comm) - println( recv_buf == send_buf ) + return recv_buf end - return nothing end -@jit aaa(comm, rank, send_buf, recv_buf, tag) -println("Rank $rank") +result = @jit aaa(comm, rank, send_buf, recv_buf, tag) -# rank==1 && sleep(5) -# println("\nRank $rank:\n", @code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) +if rank==0 + println("Rank $rank: $result") +elseif rank==1 + println("Rank $rank: $(result[2])") + println( recv_buf == send_buf ) +end +# # rank==1 && sleep(5) +# # # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +# # println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# function aaa(comm, rank, send_buf, recv_buf, tag) -# if rank == 0 -# dest = 333 -# MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 555 -# MPI.Recv!(recv_buf, src, tag, comm) -# # println( recv_buf == send_buf ) -# end -# return nothing -# end -# # @jit aaa(comm, rank, send_buf, recv_buf, tag) -# rank==1 && sleep(5) -# # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) - -# # # bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) -# # # if rank==0 -# # # println("\nlowered") -# # # println(@code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) -# # # println("\ntyped") -# # # println(@code_typed bbb(comm, rank, send_buf, recv_buf, tag)) -# # # println("\nllvm") -# # # println(@code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) -# # # end - - -# # works -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 1 -# @jit MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 0 -# @jit MPI.Recv!(recv_buf, src, tag, comm) -# println( recv_buf == send_buf ) +# if rank==0 +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\ncode_hlo:\n", @code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\ncode_xla:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) + +# bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) +# println("\nlowered:\n", @code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) +# println("\ntyped:\n", @code_typed bbb(comm, rank, send_buf, recv_buf, tag)) +# println("\nllvm:\n", @code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) # end + + +# ---------------- +# Send/Recv! 
compiled separately +# ---------------- +# # test: works # send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) # tag = 43 # if rank == 0 # dest = 1 -# # dest = 333 # @jit MPI.Send(send_buf, dest, tag, comm) -# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) # elseif rank == 1 # src = 0 -# # src = 555 # @jit MPI.Recv!(recv_buf, src, tag, comm) -# # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) # println( recv_buf == send_buf ) # end - +# debug # send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) # tag = 43 # if rank == 0 -# dest = 333 -# bbb = @compile MPI.Send(send_buf, dest, tag, comm) - -# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) - -# # println(@code_lowered bbb(send_buf, dest, tag, comm)) +# dest = 1 -# println("\nlowered") -# println(@code_lowered bbb(send_buf, dest, tag, comm)) -# println("\ntyped") -# println(@code_typed bbb(send_buf, dest, tag, comm)) -# println("\nllvm") -# println(@code_llvm bbb(send_buf, dest, tag, comm)) +# # @jit MPI.Send(send_buf, dest, tag, comm) + +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) +# println("\ncode_hlo:\n", @code_hlo MPI.Send(send_buf, dest, tag, comm)) +# println("\ncode_xla:\n", @code_xla MPI.Send(send_buf, dest, tag, comm)) + +# sss = @compile MPI.Send(send_buf, dest, tag, comm) +# println("\nlowered:\n", @code_lowered sss(send_buf, dest, tag, comm)) +# println("\ntyped:\n", @code_typed sss(send_buf, dest, tag, comm)) +# println("\nllvm:\n", @code_llvm sss(send_buf, dest, tag, comm)) # # elseif rank == 1 -# # # # src = 0 -# # # src = 555 -# # # # @jit MPI.Recv!(recv_buf, src, tag, comm) -# # # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) -# # # println( recv_buf == send_buf ) +# # src = 0 +# # @jit MPI.Recv!(recv_buf, src, tag, comm) +# # println( recv_buf == send_buf ) # end diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl index d1b9c68d85..5913ecf35c 100644 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl @@ -1,4 +1,4 @@ -using Test, MPI, Reactant +using Test, MPI, Reactant, InteractiveUtils Reactant.set_default_backend("cpu") # Reactant.set_default_backend("gpu") @@ -9,118 +9,61 @@ comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) -println("BEFORE: rank $rank") - -# # runs without crashing -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# dest = 1 -# function Isend_Wait(send_buf, dest, tag, comm) -# req = MPI.Isend(send_buf, dest, tag, comm) -# MPI.Wait(req) -# return nothing -# end -# @jit Isend_Wait(send_buf, dest, tag, comm) +send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +tag = 42 +dest = 1 +function aaa(send_buf, dest, tag, comm) + req = MPI.Isend(send_buf, dest, tag, comm) + errcode = MPI.Wait(req) + return errcode +end +@jit aaa(send_buf, dest, tag, comm) -# # runs without crashing # recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) # tag = 42 # src = 1 -# function Irecv!_Wait(recv_buf, src, tag, comm) +# function aaa(recv_buf, src, tag, comm) # req = MPI.Irecv!(recv_buf, src, tag, comm) # MPI.Wait(req) # return nothing # end -# @jit Irecv!_Wait(recv_buf, src, tag, comm) +# # @jit Irecv!_Wait(recv_buf, src, tag, comm) + +# # if rank==0 +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(recv_buf, src, tag, comm)) +# 
println("\ncode_hlo:\n", @code_hlo aaa(recv_buf, src, tag, comm)) +# println("\ncode_xla:\n", @code_xla aaa(recv_buf, src, tag, comm)) + +# bbb = @compile aaa(recv_buf, src, tag, comm) +# println("\nlowered:\n", @code_lowered bbb(recv_buf, src, tag, comm)) +# println("\ntyped:\n", @code_typed bbb(recv_buf, src, tag, comm)) +# println("\nllvm:\n", @code_llvm bbb(recv_buf, src, tag, comm)) +# # end -# # recv_buf not modified # send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) # function Isend_Irecv!(comm, rank, send_buf, recv_buf) # if rank==0 -# println("rank 0") - # dest = 1 # tag = 42 - # req = MPI.Isend(send_buf, dest, tag, comm) -# MPI.Wait(req) +# err = MPI.Wait(req) +# return err # elseif rank==1 -# println("rank 1") - # src = 0 # tag = 42 - # req = MPI.Irecv!(recv_buf, src, tag, comm) # MPI.Wait(req) +# return recv_buf # end - -# return recv_buf -# end -# recv_buf = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) -# println(recv_buf) - - -# # hangs -# send_buf = ConcreteRArray(fill(1)) -# recv_buf = ConcreteRArray(fill(12)) -# tag = 43 -# function aaa(comm, rank, send_buf, recv_buf, tag) -# if rank == 0 -# dest = 1 -# MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 0 -# MPI.Recv!(recv_buf, src, tag, comm) -# # println( recv_buf == send_buf ) -# end -# return nothing -# end -# # @jit aaa(comm, rank, send_buf, recv_buf, tag) -# # display(@code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) -# display(@code_xla aaa(comm, rank, send_buf, recv_buf, tag)) - - -# # works -# send_buf = ConcreteRArray(fill(1)) -# recv_buf = ConcreteRArray(fill(12)) -# tag = 43 -# if rank == 0 -# dest = 1 -# @jit MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 0 -# @jit MPI.Recv!(recv_buf, src, tag, comm) -# println( recv_buf == send_buf ) # end +# result = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) +# println("Rank $rank: $result") -# # hangs debug -# send_buf = ConcreteRArray(fill(1)) -# recv_buf = ConcreteRArray(fill(12)) -# tag = 43 -# function aaa(comm, rank, send_buf, recv_buf, tag) -# # if rank == 0 -# dest = 1 -# MPI.Send(send_buf, dest, tag, comm) -# # elseif rank == 1 -# src = 0 -# MPI.Recv!(recv_buf, src, tag, comm) -# # println( recv_buf == send_buf ) -# # end -# return nothing -# end -# # @jit aaa(comm, rank, send_buf, recv_buf, tag) -# # display(@code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) -# display(@code_xla aaa(comm, rank, send_buf, recv_buf, tag)) - - - - -println("AFTER: rank $rank") MPI.Finalize() From 43df184d6596ff79b4b268715cefad62a1b67891 Mon Sep 17 00:00:00 2001 From: romanlee Date: Wed, 10 Sep 2025 11:14:21 -0700 Subject: [PATCH 69/97] Cleanup testing mess --- .../2025.09.mpi/test-isend-irecv.jl | 150 ++++++----- .../2025.09.mpi/test-isend-irecv0.jl | 69 ----- .../2025.09.mpi/test-isend-irecv1.jl | 252 ------------------ .../2025.09.mpi/test-send-recv2.jl | 100 +++++++ 4 files changed, 180 insertions(+), 391 deletions(-) delete mode 100644 roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl delete mode 100644 roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl create mode 100644 roman-temp-debug/2025.09.mpi/test-send-recv2.jl diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl index edf9da9df6..dc52af22d0 100644 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl @@ -10,91 +10,101 @@ rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) -# ---------------- -# Send/Recv! 
in one func -# ---------------- -send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -tag = 43 -function aaa(comm, rank, send_buf, recv_buf, tag) - if rank == 0 - dest = 1 - # ccall(:jl_breakpoint, Cvoid, (Any,), dest) - return MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but have to return this otherwise julia optimizes this out - elseif rank == 1 - src = 0 - # return MPI.Recv!(recv_buf, src, tag, comm) - MPI.Recv!(recv_buf, src, tag, comm) - return recv_buf - end -end +# # -------------------------- +# # test MPI.jl Isend / Irecv! +# # -------------------------- +# # Skip test if not enough processes +# if nranks < 2 +# @error "Need at least 2 MPI processes for Isend/Irecv test" +# end + +# send_buf = [1, 2, 3, 4, 5] +# tag = 42 +# if rank == 0 +# dest = 1 -result = @jit aaa(comm, rank, send_buf, recv_buf, tag) +# req_send = MPI.Isend(send_buf, dest, tag, comm) -if rank==0 - println("Rank $rank: $result") -elseif rank==1 - println("Rank $rank: $(result[2])") - println( recv_buf == send_buf ) -end +# println("Rank 0: Waiting...") -# # rank==1 && sleep(5) -# # # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# # println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) +# MPI.Wait(req_send) -# if rank==0 -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\ncode_hlo:\n", @code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\ncode_xla:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) +# println("Rank 0: Sent") + +# elseif rank == 1 +# recv_buf = Vector{Int}(undef, 5) +# source = 0 + +# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) + +# println("Rank 1: Waiting...") + +# status = MPI.Wait(req_recv) + +# println( "Rank 1: Received: $(recv_buf == send_buf)" ) +# # @test MPI.Get_source(status) == 0 +# # @test MPI.Get_tag(status) == 42 -# bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) -# println("\nlowered:\n", @code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) -# println("\ntyped:\n", @code_typed bbb(comm, rank, send_buf, recv_buf, tag)) -# println("\nllvm:\n", @code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) # end +# # -------------------------- -# ---------------- -# Send/Recv! 
compiled separately -# ---------------- -# # test: works -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 1 -# @jit MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 0 -# @jit MPI.Recv!(recv_buf, src, tag, comm) -# println( recv_buf == send_buf ) +send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +tag = 42 +dest = 1 +function aaa(send_buf, dest, tag, comm) + req = MPI.Isend(send_buf, dest, tag, comm) + errcode = MPI.Wait(req) + return errcode +end +@jit aaa(send_buf, dest, tag, comm) + + + +# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# tag = 42 +# src = 1 +# function aaa(recv_buf, src, tag, comm) +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# return nothing # end +# # @jit Irecv!_Wait(recv_buf, src, tag, comm) + +# # if rank==0 +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(recv_buf, src, tag, comm)) +# println("\ncode_hlo:\n", @code_hlo aaa(recv_buf, src, tag, comm)) +# println("\ncode_xla:\n", @code_xla aaa(recv_buf, src, tag, comm)) + +# bbb = @compile aaa(recv_buf, src, tag, comm) +# println("\nlowered:\n", @code_lowered bbb(recv_buf, src, tag, comm)) +# println("\ntyped:\n", @code_typed bbb(recv_buf, src, tag, comm)) +# println("\nllvm:\n", @code_llvm bbb(recv_buf, src, tag, comm)) +# # end -# debug # send_buf = ConcreteRArray([1, 2, 3, 4, 5]) # recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 1 - -# # @jit MPI.Send(send_buf, dest, tag, comm) - -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) -# println("\ncode_hlo:\n", @code_hlo MPI.Send(send_buf, dest, tag, comm)) -# println("\ncode_xla:\n", @code_xla MPI.Send(send_buf, dest, tag, comm)) - -# sss = @compile MPI.Send(send_buf, dest, tag, comm) -# println("\nlowered:\n", @code_lowered sss(send_buf, dest, tag, comm)) -# println("\ntyped:\n", @code_typed sss(send_buf, dest, tag, comm)) -# println("\nllvm:\n", @code_llvm sss(send_buf, dest, tag, comm)) - -# # elseif rank == 1 -# # src = 0 -# # @jit MPI.Recv!(recv_buf, src, tag, comm) -# # println( recv_buf == send_buf ) +# function Isend_Irecv!(comm, rank, send_buf, recv_buf) +# if rank==0 +# dest = 1 +# tag = 42 +# req = MPI.Isend(send_buf, dest, tag, comm) +# err = MPI.Wait(req) +# return err +# elseif rank==1 +# src = 0 +# tag = 42 +# req = MPI.Irecv!(recv_buf, src, tag, comm) +# MPI.Wait(req) +# return recv_buf +# end # end +# result = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) +# println("Rank $rank: $result") + + MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl deleted file mode 100644 index 5913ecf35c..0000000000 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv0.jl +++ /dev/null @@ -1,69 +0,0 @@ -using Test, MPI, Reactant, InteractiveUtils - -Reactant.set_default_backend("cpu") -# Reactant.set_default_backend("gpu") - -MPI.Init() - -comm = MPI.COMM_WORLD -rank = MPI.Comm_rank(comm) -nranks = MPI.Comm_size(comm) - -send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -tag = 42 -dest = 1 -function aaa(send_buf, dest, tag, comm) - req = MPI.Isend(send_buf, dest, tag, comm) - errcode = MPI.Wait(req) - return errcode -end -@jit aaa(send_buf, dest, tag, comm) - - - -# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# src = 1 -# function aaa(recv_buf, src, tag, comm) -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# return nothing -# end -# # @jit 
Irecv!_Wait(recv_buf, src, tag, comm) - -# # if rank==0 -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(recv_buf, src, tag, comm)) -# println("\ncode_hlo:\n", @code_hlo aaa(recv_buf, src, tag, comm)) -# println("\ncode_xla:\n", @code_xla aaa(recv_buf, src, tag, comm)) - -# bbb = @compile aaa(recv_buf, src, tag, comm) -# println("\nlowered:\n", @code_lowered bbb(recv_buf, src, tag, comm)) -# println("\ntyped:\n", @code_typed bbb(recv_buf, src, tag, comm)) -# println("\nllvm:\n", @code_llvm bbb(recv_buf, src, tag, comm)) -# # end - - -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# function Isend_Irecv!(comm, rank, send_buf, recv_buf) -# if rank==0 -# dest = 1 -# tag = 42 -# req = MPI.Isend(send_buf, dest, tag, comm) -# err = MPI.Wait(req) -# return err -# elseif rank==1 -# src = 0 -# tag = 42 -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# return recv_buf -# end -# end -# result = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) -# println("Rank $rank: $result") - - - - -MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl deleted file mode 100644 index 12b48afc6a..0000000000 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv1.jl +++ /dev/null @@ -1,252 +0,0 @@ -using Test, MPI, Reactant, InteractiveUtils - -Reactant.set_default_backend("cpu") -# Reactant.set_default_backend("gpu") - -MPI.Init() - -comm = MPI.COMM_WORLD -rank = MPI.Comm_rank(comm) -nranks = MPI.Comm_size(comm) - -# # -------------------------- -# # test MPI.jl Isend / Irecv! -# # -------------------------- -# # Skip test if not enough processes -# if nranks < 2 -# @error "Need at least 2 MPI processes for Isend/Irecv test" -# end - -# send_buf = [1, 2, 3, 4, 5] -# tag = 42 -# if rank == 0 -# dest = 1 - -# req_send = MPI.Isend(send_buf, dest, tag, comm) - -# println("Rank 0: Waiting...") - -# MPI.Wait(req_send) - -# println("Rank 0: Sent") - -# elseif rank == 1 -# recv_buf = Vector{Int}(undef, 5) -# source = 0 - -# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) - -# println("Rank 1: Waiting...") - -# status = MPI.Wait(req_recv) - -# println( "Rank 1: Received: $(recv_buf == send_buf)" ) -# # @test MPI.Get_source(status) == 0 -# # @test MPI.Get_tag(status) == 42 - -# end -# # -------------------------- - - -# -------------------------- -# # test Reactant Isend -# -------------------------- -# if nranks < 2 -# @error "Need at least 2 MPI processes for Isend/Irecv test" -# end -# -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# if rank == 0 -# dest = 1 - -# req_send = @jit MPI.Isend(send_buf, dest, tag, comm) - -# MPI.Wait(req_send) - -# elseif rank == 1 -# recv_buf = Vector{Int}(undef, 5) -# source = 0 - -# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) - -# status = MPI.Wait(req_recv) - -# println( recv_buf == send_buf ) -# # @test MPI.Get_source(status) == 0 -# # @test MPI.Get_tag(status) == 42 - -# end - - -# -------------------------- -# debug -# -------------------------- -# # runs without crashing -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# dest = 1 -# function Isend_Wait(send_buf, dest, tag, comm) -# req = MPI.Isend(send_buf, dest, tag, comm) -# MPI.Wait(req) -# return nothing -# end -# @jit Isend_Wait(send_buf, dest, tag, comm) -# println(@code_hlo optimize=false Isend_Wait(send_buf, dest, tag, comm)) - - -# # runs without crashing -# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# tag = 42 -# src = 1 -# # function 
Irecv_Wait(recv_buf, src, tag, comm) -# function Irecv_Wait(recv_buf, src, tag, comm) -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# return nothing -# end -# # @jit Irecv_Wait(recv_buf, src, tag, comm) -# println(@code_hlo optimize=false Irecv_Wait(recv_buf, src, tag, comm)) - - -# # recv_buf not modified -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# function Isend_Irecv!(comm, rank, send_buf, recv_buf) -# if rank==0 -# dest = 1 -# tag = 42 -# req = MPI.Isend(send_buf, dest, tag, comm) -# MPI.Wait(req) -# elseif rank==1 -# src = 0 -# tag = 42 -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# end - -# return recv_buf -# end -# # recv_buf = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) -# -# rank==1 && sleep(3) -# println("\nRank: $rank") -# println(@code_hlo optimize=false Isend_Irecv!(comm, rank, send_buf, recv_buf)) - - -# # hangs -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# function aaa(comm, rank, send_buf, recv_buf, tag) -# if rank == 0 -# dest = 1 -# # ccall(:jl_breakpoint, Cvoid, (Any,), dest) -# MPI.Send(send_buf, dest, tag, comm) -# # while true -# # sleep(5) -# # end -# elseif rank == 1 -# src = 0 -# # ccall(:jl_breakpoint, Cvoid, (Any,), src) -# MPI.Recv!(recv_buf, src, tag, comm) -# # println( recv_buf == send_buf ) -# # while true -# # sleep(5) -# # end -# end -# return nothing -# end -# @jit aaa(comm, rank, send_buf, recv_buf, tag) -# # println("\nRank: $rank") - - -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# function aaa(comm, rank, send_buf, recv_buf, tag) -# if rank == 0 -# dest = 333 -# MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 555 -# MPI.Recv!(recv_buf, src, tag, comm) -# # println( recv_buf == send_buf ) -# end -# return nothing -# end -# # @jit aaa(comm, rank, send_buf, recv_buf, tag) -# rank==1 && sleep(5) -# # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) - -# # # bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) -# # # if rank==0 -# # # println("\nlowered") -# # # println(@code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) -# # # println("\ntyped") -# # # println(@code_typed bbb(comm, rank, send_buf, recv_buf, tag)) -# # # println("\nllvm") -# # # println(@code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) -# # # end - - -# # works -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 1 -# @jit MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 0 -# @jit MPI.Recv!(recv_buf, src, tag, comm) -# println( recv_buf == send_buf ) -# end - -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 1 -# # dest = 333 -# @jit MPI.Send(send_buf, dest, tag, comm) -# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) -# elseif rank == 1 -# src = 0 -# # src = 555 -# @jit MPI.Recv!(recv_buf, src, tag, comm) -# # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) -# println( recv_buf == send_buf ) -# end - - - -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 333 -# bbb = @compile MPI.Send(send_buf, dest, 
tag, comm) - -# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) - -# # println(@code_lowered bbb(send_buf, dest, tag, comm)) - -# println("\nlowered") -# println(@code_lowered bbb(send_buf, dest, tag, comm)) -# println("\ntyped") -# println(@code_typed bbb(send_buf, dest, tag, comm)) -# println("\nllvm") -# println(@code_llvm bbb(send_buf, dest, tag, comm)) - -# # elseif rank == 1 -# # # # src = 0 -# # # src = 555 -# # # # @jit MPI.Recv!(recv_buf, src, tag, comm) -# # # println(@code_hlo optimize=false MPI.Recv!(recv_buf, src, tag, comm)) -# # # println( recv_buf == send_buf ) -# end - - -MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-send-recv2.jl b/roman-temp-debug/2025.09.mpi/test-send-recv2.jl new file mode 100644 index 0000000000..edf9da9df6 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/test-send-recv2.jl @@ -0,0 +1,100 @@ +using Test, MPI, Reactant, InteractiveUtils + +Reactant.set_default_backend("cpu") +# Reactant.set_default_backend("gpu") + +MPI.Init() + +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +nranks = MPI.Comm_size(comm) + + +# ---------------- +# Send/Recv! in one func +# ---------------- +send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +tag = 43 +function aaa(comm, rank, send_buf, recv_buf, tag) + if rank == 0 + dest = 1 + # ccall(:jl_breakpoint, Cvoid, (Any,), dest) + return MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but have to return this otherwise julia optimizes this out + elseif rank == 1 + src = 0 + # return MPI.Recv!(recv_buf, src, tag, comm) + MPI.Recv!(recv_buf, src, tag, comm) + return recv_buf + end +end + +result = @jit aaa(comm, rank, send_buf, recv_buf, tag) + +if rank==0 + println("Rank $rank: $result") +elseif rank==1 + println("Rank $rank: $(result[2])") + println( recv_buf == send_buf ) +end + +# # rank==1 && sleep(5) +# # # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +# # println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) + +# if rank==0 +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\ncode_hlo:\n", @code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\ncode_xla:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) + +# bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) +# println("\nlowered:\n", @code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) +# println("\ntyped:\n", @code_typed bbb(comm, rank, send_buf, recv_buf, tag)) +# println("\nllvm:\n", @code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) +# end + + + +# ---------------- +# Send/Recv! 
compiled separately +# ---------------- +# # test: works +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# dest = 1 +# @jit MPI.Send(send_buf, dest, tag, comm) +# elseif rank == 1 +# src = 0 +# @jit MPI.Recv!(recv_buf, src, tag, comm) +# println( recv_buf == send_buf ) +# end + + +# debug +# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) +# tag = 43 +# if rank == 0 +# dest = 1 + +# # @jit MPI.Send(send_buf, dest, tag, comm) + +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) +# println("\ncode_hlo:\n", @code_hlo MPI.Send(send_buf, dest, tag, comm)) +# println("\ncode_xla:\n", @code_xla MPI.Send(send_buf, dest, tag, comm)) + +# sss = @compile MPI.Send(send_buf, dest, tag, comm) +# println("\nlowered:\n", @code_lowered sss(send_buf, dest, tag, comm)) +# println("\ntyped:\n", @code_typed sss(send_buf, dest, tag, comm)) +# println("\nllvm:\n", @code_llvm sss(send_buf, dest, tag, comm)) + +# # elseif rank == 1 +# # src = 0 +# # @jit MPI.Recv!(recv_buf, src, tag, comm) +# # println( recv_buf == send_buf ) +# end + + +MPI.Finalize() From 4aa4371e3099325ff7e0392d18407024fc1de7ed Mon Sep 17 00:00:00 2001 From: romanlee Date: Wed, 10 Sep 2025 12:40:28 -0700 Subject: [PATCH 70/97] Fix typo in Ops.irecv!(), return error code from Wait Isend/Irecv test still not working --- ext/ReactantMPIExt/Ops.jl | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 4ba1fee0cf..d50550def1 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -519,7 +519,7 @@ function irecv!( buf::TracedRArray, tag::TracedRNumber, src::TracedRNumber; - location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), + location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), ) T = Reactant.unwrapped_eltype(buf) mpi_datatype = convert_julia_type_to_mpi_datatype(T) @@ -585,27 +585,43 @@ function wait( # # TODO Temporarily commented out bc otherwise can't compile together with any other # # func that tries to inject the same thing (e.g., isend, irecv) # IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") #! format: off IR.inject!(sym_name, """ - func.func @$sym_name(%req : !llvm.ptr) -> () { + func.func @$sym_name(%errcode : !llvm.ptr, %req : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr - %errcode = llvm.call @MPI_Wait(%req, %comm) : (!llvm.ptr, !llvm.ptr) -> (i32) + %res = llvm.call @MPI_Wait(%req, %comm) : (!llvm.ptr, !llvm.ptr) -> (i32) + llvm.store %res, %errcode : i32, !llvm.ptr func.return } """) #! 
format: on - enzymexla.jit_call( - IR.Value[req.mlir_data]; + errcode = Reactant.Ops.constant(fill(Cint(0))) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) + + ret = enzymexla.jit_call( + IR.Value[errcode.mlir_data, req.mlir_data]; fn=sym_attr, - result_0=IR.Type[], + # result_0=IR.Type[], + result_0=IR.Type[mlir_type(errcode)], location, - output_operand_aliases=IR.Attribute(IR.Attribute[]), + # output_operand_aliases=IR.Attribute(IR.Attribute[]), + output_operand_aliases=output_operand_aliases, ) - return nothing + # return nothing + errcode.mlir_data = IR.result(ret) + return errcode end function inject_mpi_op!(op) From bf8ecccbaf1a3c4052ba9368752ae3f30512e87b Mon Sep 17 00:00:00 2001 From: romanlee Date: Wed, 10 Sep 2025 15:33:47 -0700 Subject: [PATCH 71/97] Update isend/irecv tests --- .../2025.09.mpi/test-isend-irecv.jl | 100 +++++++++--------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl index dc52af22d0..5f39bd54fd 100644 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl @@ -19,92 +19,96 @@ nranks = MPI.Comm_size(comm) # end # send_buf = [1, 2, 3, 4, 5] +# send_buf = ones(5) +# recv_buf = zeros(5) # tag = 42 # if rank == 0 # dest = 1 - # req_send = MPI.Isend(send_buf, dest, tag, comm) - # println("Rank 0: Waiting...") - # MPI.Wait(req_send) - # println("Rank 0: Sent") - # elseif rank == 1 -# recv_buf = Vector{Int}(undef, 5) # source = 0 - # req_recv = MPI.Irecv!(recv_buf, source, tag, comm) - # println("Rank 1: Waiting...") - # status = MPI.Wait(req_recv) - # println( "Rank 1: Received: $(recv_buf == send_buf)" ) # # @test MPI.Get_source(status) == 0 # # @test MPI.Get_tag(status) == 42 - # end # # -------------------------- -send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -tag = 42 -dest = 1 -function aaa(send_buf, dest, tag, comm) - req = MPI.Isend(send_buf, dest, tag, comm) - errcode = MPI.Wait(req) - return errcode -end -@jit aaa(send_buf, dest, tag, comm) +# send_buf = ConcreteRArray(ones(5)) +# tag = 42 +# dest = 1 +# function aaa(send_buf, dest, tag, comm) +# req = MPI.Isend(send_buf, dest, tag, comm) +# errcode = MPI.Wait(req) +# return errcode +# end +# @jit aaa(send_buf, dest, tag, comm) +# rank==1 && sleep(10) +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(send_buf, dest, tag, comm)) +# println("\ncode_hlo:\n", @code_hlo aaa(send_buf, dest, tag, comm)) +# println("\ncode_xla:\n", @code_xla aaa(send_buf, dest, tag, comm)) +# bbb = @compile aaa(send_buf, dest, tag, comm) +# println("\nlowered:\n", @code_lowered bbb(send_buf, dest, tag, comm)) +# # println("\ntyped:\n", @code_typed bbb(send_buf, dest, tag, comm)) +# # println("\nllvm:\n", @code_llvm bbb(send_buf, dest, tag, comm)) -# recv_buf = ConcreteRArray([1, 2, 3, 4, 5]) +# recv_buf = ConcreteRArray(zeros(5)) # tag = 42 # src = 1 # function aaa(recv_buf, src, tag, comm) # req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# return nothing +# errcode = MPI.Wait(req) +# return errcode, recv_buf # end -# # @jit Irecv!_Wait(recv_buf, src, tag, comm) +# @jit aaa(recv_buf, src, tag, comm) -# # if rank==0 -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(recv_buf, src, tag, comm)) -# println("\ncode_hlo:\n", @code_hlo aaa(recv_buf, src, tag, comm)) -# println("\ncode_xla:\n", @code_xla 
aaa(recv_buf, src, tag, comm)) +# rank==1 && sleep(10) +# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(recv_buf, src, tag, comm)) +# println("\ncode_hlo:\n", @code_hlo aaa(recv_buf, src, tag, comm)) +# println("\ncode_xla:\n", @code_xla aaa(recv_buf, src, tag, comm)) +# bbb = @compile aaa(recv_buf, src, tag, comm) +# println("\nlowered:\n", @code_lowered bbb(recv_buf, src, tag, comm)) +# # println("\ntyped:\n", @code_typed bbb(recv_buf, src, tag, comm)) +# # println("\nllvm:\n", @code_llvm bbb(recv_buf, src, tag, comm)) -# bbb = @compile aaa(recv_buf, src, tag, comm) -# println("\nlowered:\n", @code_lowered bbb(recv_buf, src, tag, comm)) -# println("\ntyped:\n", @code_typed bbb(recv_buf, src, tag, comm)) -# println("\nllvm:\n", @code_llvm bbb(recv_buf, src, tag, comm)) -# # end - -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# function Isend_Irecv!(comm, rank, send_buf, recv_buf) +# send_buf = ConcreteRArray(ones(5)) +# recv_buf = ConcreteRArray(zeros(5)) +# tag = 42 +# function aaa(send_buf, recv_buf, rank, tag, comm) # if rank==0 # dest = 1 -# tag = 42 # req = MPI.Isend(send_buf, dest, tag, comm) -# err = MPI.Wait(req) -# return err +# # errcode = MPI.Wait(req) +# # return errcode +# return nothing # elseif rank==1 -# src = 0 -# tag = 42 +# src = 1 # req = MPI.Irecv!(recv_buf, src, tag, comm) -# MPI.Wait(req) -# return recv_buf +# errcode = MPI.Wait(req) +# return errcode, recv_buf # end # end -# result = @jit Isend_Irecv!(comm, rank, send_buf, recv_buf) -# println("Rank $rank: $result") - - +# @jit aaa(send_buf, recv_buf, rank, tag, comm) +# println("Rank $rank returned") + +# # rank==1 && sleep(10) +# # println("\n$rank: code_hlo optimize=false:\n", @code_hlo optimize=false aaa(send_buf, recv_buf, rank, tag, comm)) +# # println("\n$rank: code_hlo:\n", @code_hlo aaa(send_buf, recv_buf, rank, tag, comm)) +# # println("\n$rank: code_xla:\n", @code_xla aaa(send_buf, recv_buf, rank, tag, comm)) +# # bbb = @compile aaa(send_buf, recv_buf, rank, tag, comm) +# # println("\n$rank: lowered:\n", @code_lowered bbb(send_buf, recv_buf, rank, tag, comm)) +# # println("\n$rank: typed:\n", @code_typed bbb(send_buf, recv_buf, rank, tag, comm)) +# # println("\n$rank: llvm:\n", @code_llvm bbb(send_buf, recv_buf, rank, tag, comm)) MPI.Finalize() From d1ca937afef7084796bb6ba7aa2041f4e48e37a7 Mon Sep 17 00:00:00 2001 From: romanlee Date: Thu, 11 Sep 2025 16:00:01 -0700 Subject: [PATCH 72/97] irecv!: Need to return recvbuf and modify mlir data I think --- ext/ReactantMPIExt/Ops.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index d50550def1..aee8fbfc94 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -556,7 +556,12 @@ function irecv!( output_operand_aliases = IR.Attribute([ IR.Attribute( MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 0, C_NULL, 4, 0, C_NULL + MLIR.IR.context(), 1, Ref{Int64}(0), 0, 0, C_NULL + ), + ), + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 1, Ref{Int64}(1), 4, 0, C_NULL ), ), ]) @@ -566,14 +571,14 @@ function irecv!( buf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data, request.mlir_data ]; fn=sym_attr, - result_0=IR.Type[mlir_type(request)], + result_0=[mlir_type(buf), mlir_type(request)], output_operand_aliases=output_operand_aliases, location, ) - # return TracedRNumber - request.mlir_data = IR.result(ret) - return request + 
buf.mlir_data = IR.result(ret, 1) + request.mlir_data = IR.result(ret, 2) + return request # we return a TracedRNumber, converted to TracedRequest in Overrides.jl end function wait( @@ -612,14 +617,11 @@ function wait( ret = enzymexla.jit_call( IR.Value[errcode.mlir_data, req.mlir_data]; fn=sym_attr, - # result_0=IR.Type[], result_0=IR.Type[mlir_type(errcode)], location, - # output_operand_aliases=IR.Attribute(IR.Attribute[]), output_operand_aliases=output_operand_aliases, ) - # return nothing errcode.mlir_data = IR.result(ret) return errcode end From 7481fcd15d39482dd3f64cf913aec883b7525afc Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Mon, 15 Sep 2025 20:18:05 -0500 Subject: [PATCH 73/97] Fix typo in MPI wait --- ext/ReactantMPIExt/Ops.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index aee8fbfc94..44dba018fa 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -594,10 +594,12 @@ function wait( IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") #! format: off + # int MPI_Wait(MPI_Request* request, MPI_Status* status) IR.inject!(sym_name, """ func.func @$sym_name(%errcode : !llvm.ptr, %req : !llvm.ptr) -> () { - %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr - %res = llvm.call @MPI_Wait(%req, %comm) : (!llvm.ptr, !llvm.ptr) -> (i32) + %c1_i32 = arith.constant 1 : i32 + %status = llvm.alloca %c1_i32 x !llvm.array<4 x i32> : (i32) -> !llvm.ptr + %res = llvm.call @MPI_Wait(%req, %status) : (!llvm.ptr, !llvm.ptr) -> (i32) llvm.store %res, %errcode : i32, !llvm.ptr func.return } From 0395fb3933f2deb87a13c013b94885069fe13496 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Mon, 15 Sep 2025 20:30:17 -0500 Subject: [PATCH 74/97] Add launcher.jl --- roman-temp-debug/2025.09.mpi/launcher.jl | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 roman-temp-debug/2025.09.mpi/launcher.jl diff --git a/roman-temp-debug/2025.09.mpi/launcher.jl b/roman-temp-debug/2025.09.mpi/launcher.jl new file mode 100644 index 0000000000..7cd95dddf9 --- /dev/null +++ b/roman-temp-debug/2025.09.mpi/launcher.jl @@ -0,0 +1,4 @@ +# launcher.jl +# usage e.g.: julia launcher.jl 1 test-isend-irecv.jl +using MPI +run(`$(MPI.mpiexec()) -n $(ARGS[1]) julia --project $(ARGS[2])`) From a1b9218bb975dde0b7a3ebf03146fc41bf845004 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Wed, 17 Sep 2025 16:21:14 -0500 Subject: [PATCH 75/97] Don't need this until we need to let a Request cross compile barrier --- ext/ReactantMPIExt/ReactantMPIExt.jl | 7 ++++--- src/Reactant.jl | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index f69573a622..49cfdf9d46 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -36,9 +36,9 @@ function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) end function __init__() - # TODO improve this, temporary hack - # when you fix it, remember to possibly make TracedType const again - Reactant.TracedType = Union{Reactant.TracedRArray,Reactant.TracedRNumber,Reactant.MissingTracedValue,TracedRequest} + # # TODO improve this, temporary hack + # # when you fix it, remember to possibly make TracedType const again + # Reactant.TracedType = Union{Reactant.TracedRArray,Reactant.TracedRNumber,Reactant.MissingTracedValue,TracedRequest} # TODO maybe it's more efficient if we use `RTLD_NOW` instead of `RTLD_LAZY`? 
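    # (RTLD_NOW would resolve all of libmpi's symbols at load time, while RTLD_LAZY
    # defers resolution until a symbol is first used, so the practical difference here
    # is probably small. RTLD_GLOBAL is likely the flag that matters: it makes the MPI
    # symbols available for resolution outside this handle as well.)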
libmpi_handle = Libdl.dlopen(MPI.API.libmpi, RTLD_LAZY | RTLD_GLOBAL) @@ -255,6 +255,7 @@ Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data # return MLIR.IR.TensorType(collect(Int, ()), MLIR.IR.Type(Int64)) # end # +# TODO if want to use this, need to somehow add TracedRequest to TracedTypes, which is currently const # Base.@nospecializeinfer function Reactant.make_tracer( # seen, # @nospecialize(prev::TracedRequest), diff --git a/src/Reactant.jl b/src/Reactant.jl index ca53005fa4..9e38f92347 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -233,8 +233,7 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") -# const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} -TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} +const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} include("ControlFlow.jl") include("Tracing.jl") From 5ee8762ae3282e4f4a3d5f9d036d7e910994e855 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 14:12:42 -0500 Subject: [PATCH 76/97] Set size of status to max necessary hopefully --- ext/ReactantMPIExt/Ops.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 44dba018fa..0d6f2bd7c2 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -595,10 +595,11 @@ function wait( #! format: off # int MPI_Wait(MPI_Request* request, MPI_Status* status) + # Size of status is implem dependent, we try to set it to the max IR.inject!(sym_name, """ func.func @$sym_name(%errcode : !llvm.ptr, %req : !llvm.ptr) -> () { %c1_i32 = arith.constant 1 : i32 - %status = llvm.alloca %c1_i32 x !llvm.array<4 x i32> : (i32) -> !llvm.ptr + %status = llvm.alloca %c1_i32 x !llvm.array<6 x i32> : (i32) -> !llvm.ptr %res = llvm.call @MPI_Wait(%req, %status) : (!llvm.ptr, !llvm.ptr) -> (i32) llvm.store %res, %errcode : i32, !llvm.ptr func.return From 4c9e2ac1ab335cbe8c986e184cd998f51d10e909 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 14:21:46 -0500 Subject: [PATCH 77/97] Add Isend/Irecv/Wait unit test --- test/integration/mpi.jl | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index 79aed4b04c..1ce78c386f 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -1,7 +1,7 @@ using Test, MPI, Reactant -# # MPI only works on cpu currently --- is this the right way/place to enforce that? -# Reactant.set_default_backend("cpu") +# MPI only works on cpu currently --- is this the right way/place to enforce that? +Reactant.set_default_backend("cpu") MPI.Init() @@ -30,6 +30,7 @@ end nranks = MPI.Comm_size(comm) # test MPI.jl Send / Reactant Recv + # useful to isolate Reactant issues @testset "MPI.jl Send / Reactant Recv!" begin send_buf = ones(5) tag = 43 @@ -44,6 +45,7 @@ end end # test Reactant Send / MPI.jl Recv + # useful to isolate Reactant issues @testset "Reactant Send / MPI.jl Recv!" 
begin send_buf = ConcreteRArray(ones(5)) tag = 43 @@ -93,4 +95,30 @@ end end +@testset "Isend, Irecv!, Wait" begin + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + nranks = MPI.Comm_size(comm) + + # note: currently don't allow a request to cross the compile boundary + send_buf = ConcreteRArray(ones(5)) + recv_buf = ConcreteRArray(zeros(5)) + tag = 42 + function isendirecvwait(send_buf, recv_buf, rank, tag, comm) + if rank==0 + dest = 1 + req = MPI.Isend(send_buf, dest, tag, comm) + MPI.Wait(req) + return nothing + elseif rank==1 + src = 0 + req = MPI.Irecv!(recv_buf, src, tag, comm) + MPI.Wait(req) + return recv_buf + end + end + @jit isendirecvwait(send_buf, recv_buf, rank, tag, comm) + rank==1 && @test recv_buf == send_buf +end + MPI.Finalize() From 81d2d5867bd877453eaa83384e1d0e0c59e497d1 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 14:33:04 -0500 Subject: [PATCH 78/97] Cleanup return vals in tests --- test/integration/mpi.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index 1ce78c386f..caade5f2ff 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -24,7 +24,7 @@ end @test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD) end -@testset "Send, Recv!" begin +@testset "Send / Recv!" begin comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) @@ -81,12 +81,12 @@ end function sendrecv!(comm, rank, send_buf, recv_buf, tag) if rank == 0 dest = 1 - err_code = MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but unfort have to return something otherwise julia optimizes this out @code_lowered - return err_code + MPI.Send(send_buf, dest, tag, comm) + return nothing elseif rank == 1 src = 0 MPI.Recv!(recv_buf, src, tag, comm) - return recv_buf + return nothing end end @jit sendrecv!(comm, rank, send_buf, recv_buf, tag) @@ -95,7 +95,7 @@ end end -@testset "Isend, Irecv!, Wait" begin +@testset "Isend / Irecv! / Wait" begin comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) @@ -114,7 +114,7 @@ end src = 0 req = MPI.Irecv!(recv_buf, src, tag, comm) MPI.Wait(req) - return recv_buf + return nothing end end @jit isendirecvwait(send_buf, recv_buf, rank, tag, comm) From bf35523ebccaeb5cb73a64a913d81ef32223b530 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 14:59:37 -0500 Subject: [PATCH 79/97] No need to retun error codes from Send/Recv! now that we have this is_pure stuff --- ext/ReactantMPIExt/Ops.jl | 78 +++++++++------------------------ ext/ReactantMPIExt/Overrides.jl | 2 - 2 files changed, 20 insertions(+), 60 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 0d6f2bd7c2..129490e87d 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -319,45 +319,35 @@ function send( "llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32", ) + # int MPI_Send(const void* buf, int count, MPI_Datatype datatype, + # int dest, int tag, MPI_Comm comm) #! 
format: off - # int MPI_Send(const void* buf, int count, MPI_Datatype datatype, - # int dest, int tag, MPI_Comm comm) IR.inject!(sym_name, """ - func.func @$sym_name(%errcode : !llvm.ptr, %buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr %count = llvm.load %count_ptr : !llvm.ptr -> i32 %dest = llvm.load %dest_ptr : !llvm.ptr -> i32 %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 - %res = llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> (i32) - llvm.store %res, %errcode : i32, !llvm.ptr + llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> (i32) func.return } """) #! format: on count = Reactant.Ops.constant(Int32(length(buf))) - errcode = Reactant.Ops.constant(fill(Cint(0))) - - output_operand_aliases = IR.Attribute([ - IR.Attribute( - MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL - ), - ), - ]) - ret = enzymexla.jit_call( + enzymexla.jit_call( IR.Value[ - errcode.mlir_data, buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data + buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data ]; fn=sym_attr, - result_0=IR.Type[mlir_type(errcode)], - output_operand_aliases=output_operand_aliases, + result_0=IR.Type[], + output_operand_aliases=IR.Attribute(IR.Attribute[]), location, ) - errcode.mlir_data = IR.result(ret) - return errcode + + return nothing end # TODO need c-function for creating MLIR `mpi.request` type? @@ -417,9 +407,8 @@ function isend( location, ) - # return TracedRNumber request.mlir_data = IR.result(ret) - return request + return request # we return a TracedRNumber, converted to TracedRequest in Overrides.jl end function recv!( @@ -446,75 +435,48 @@ function recv!( #! format: off IR.inject!(sym_name, """ - func.func @$sym_name(%errcode : !llvm.ptr, %buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr %count = llvm.load %count_ptr : !llvm.ptr -> i32 %source = llvm.load %source_ptr : !llvm.ptr -> i32 %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 - %res = llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) - llvm.store %res, %errcode : i32, !llvm.ptr + llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) func.return } """) #! 
format: on count = Reactant.Ops.constant(Int32(length(recvbuf))) - errcode = Reactant.Ops.constant(fill(Cint(0))) output_operand_aliases = IR.Attribute([ IR.Attribute( MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 1, Ref{Int64}(0), 0, 0, C_NULL - ), - ), - IR.Attribute( - MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 1, Ref{Int64}(1), 1, 0, C_NULL + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL ), ), ]) ret = enzymexla.jit_call( IR.Value[ - errcode.mlir_data, recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data, ]; fn=sym_attr, - result_0=[mlir_type(errcode), mlir_type(recvbuf)], + result_0=[mlir_type(recvbuf)], output_operand_aliases, location, ) - errcode.mlir_data = IR.result(ret, 1) - recvbuf.mlir_data = IR.result(ret, 2) + recvbuf.mlir_data = IR.result(ret) - # TODO is returning recvbuf the best choice here stylistically? - return errcode, recvbuf + return recvbuf end -# # TODO need c-function for creating MLIR `mpi.request` type? -# function irecv!( -# ref::TracedRArray, -# tag::TracedRNumber, -# src::TracedRNumber; -# location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), -# ) -# # return TracedRequest( -# # MLIR.IR.result(mpi.irecv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)) -# # ) -# inputs = IR.Value[ref.mlir_data, tag.mlir_data, src.mlir_data] -# sym = IR.FlatSymbolRefAttribute("enzymexla_wrapper_MPI_Irecv") -# rettype = IR.Type[] -# -# IR.result(enzymexla.jit_call(inputs; fn=sym, result_0=rettype, location)) -# return ref -# end - +# TODO need c-function for creating MLIR `mpi.request` type? function irecv!( buf::TracedRArray, tag::TracedRNumber, @@ -593,9 +555,9 @@ function wait( IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") + # NOTE: Size of status is implem dependent, we try to set it to the max + # int MPI_Wait(MPI_Request* request, MPI_Status* status) #! format: off - # int MPI_Wait(MPI_Request* request, MPI_Status* status) - # Size of status is implem dependent, we try to set it to the max IR.inject!(sym_name, """ func.func @$sym_name(%errcode : !llvm.ptr, %req : !llvm.ptr) -> () { %c1_i32 = arith.constant 1 : i32 diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 535ab617cd..5b08d81f28 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -44,8 +44,6 @@ function MPI.Send( end # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` -# NOTE unlike MPI.jl's `MPI.Send`, we return the errcode to generate the data dep -# that prevents it from being optimized away function MPI.Send( buf::TracedRArray, dest::TracedRNumber, From 8a8dcf708d6395188fde962811f64a66d54a4c3a Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 15:02:07 -0500 Subject: [PATCH 80/97] No need to return errorcode from wait either --- ext/ReactantMPIExt/Ops.jl | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 129490e87d..e56cd3476e 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -559,36 +559,24 @@ function wait( # int MPI_Wait(MPI_Request* request, MPI_Status* status) #! 
format: off IR.inject!(sym_name, """ - func.func @$sym_name(%errcode : !llvm.ptr, %req : !llvm.ptr) -> () { + func.func @$sym_name(%req : !llvm.ptr) -> () { %c1_i32 = arith.constant 1 : i32 %status = llvm.alloca %c1_i32 x !llvm.array<6 x i32> : (i32) -> !llvm.ptr - %res = llvm.call @MPI_Wait(%req, %status) : (!llvm.ptr, !llvm.ptr) -> (i32) - llvm.store %res, %errcode : i32, !llvm.ptr + llvm.call @MPI_Wait(%req, %status) : (!llvm.ptr, !llvm.ptr) -> (i32) func.return } """) #! format: on - errcode = Reactant.Ops.constant(fill(Cint(0))) - - output_operand_aliases = IR.Attribute([ - IR.Attribute( - MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL - ), - ), - ]) - - ret = enzymexla.jit_call( - IR.Value[errcode.mlir_data, req.mlir_data]; + enzymexla.jit_call( + IR.Value[req.mlir_data]; fn=sym_attr, - result_0=IR.Type[mlir_type(errcode)], + result_0=IR.Type[], location, - output_operand_aliases=output_operand_aliases, + output_operand_aliases=IR.Attribute(IR.Attribute[]), ) - errcode.mlir_data = IR.result(ret) - return errcode + return nothing end function inject_mpi_op!(op) From 10248889c8cbc6ef6c15e475503ed2c8e26ed0f6 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 15:09:01 -0500 Subject: [PATCH 81/97] Remove set_default_backend("cpu") from mpi.jl - Not sure that's the right way/place to do it --- test/integration/mpi.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index caade5f2ff..3f70ab93d3 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -1,7 +1,7 @@ using Test, MPI, Reactant -# MPI only works on cpu currently --- is this the right way/place to enforce that? -Reactant.set_default_backend("cpu") +# # MPI only works on cpu currently --- is this the right way/place to enforce that? 
+# Reactant.set_default_backend("cpu") MPI.Init() From 0a1258d0aa467972aaf0368c461ebd066c658a73 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 16:09:44 -0500 Subject: [PATCH 82/97] Use tryinject instead of inject in wait for MPI_COMM_WORLD --- ext/ReactantMPIExt/Ops.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index e56cd3476e..f448918eb3 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -549,9 +549,8 @@ function wait( sym_name = "enzymexla_wrapper_MPI_Wait" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # # TODO Temporarily commented out bc otherwise can't compile together with any other - # # func that tries to inject the same thing (e.g., isend, irecv) - # IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + # likely isend/irecv will have injected MPI_COMM_WORLD already + IR.tryinject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") From c017627c987602fcec9d6963a2a0980f0ed42b29 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 16:13:09 -0500 Subject: [PATCH 83/97] UPdate tests --- roman-temp-debug/2025.09.mpi/launcher.jl | 1 + .../2025.09.mpi/test-isend-irecv.jl | 100 +++++++++++------- .../2025.09.mpi/test-send-recv2.jl | 41 +++---- 3 files changed, 79 insertions(+), 63 deletions(-) diff --git a/roman-temp-debug/2025.09.mpi/launcher.jl b/roman-temp-debug/2025.09.mpi/launcher.jl index 7cd95dddf9..c25f81d633 100644 --- a/roman-temp-debug/2025.09.mpi/launcher.jl +++ b/roman-temp-debug/2025.09.mpi/launcher.jl @@ -1,4 +1,5 @@ # launcher.jl # usage e.g.: julia launcher.jl 1 test-isend-irecv.jl using MPI +println(MPI.identify_implementation()) run(`$(MPI.mpiexec()) -n $(ARGS[1]) julia --project $(ARGS[2])`) diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl index 5f39bd54fd..ee73d5d531 100644 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl +++ b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl @@ -11,14 +11,14 @@ nranks = MPI.Comm_size(comm) # # -------------------------- -# # test MPI.jl Isend / Irecv! # # -------------------------- +# # test MPI.jl Isend / Irecv! 
# # Skip test if not enough processes # if nranks < 2 # @error "Need at least 2 MPI processes for Isend/Irecv test" # end -# send_buf = [1, 2, 3, 4, 5] +# # send_buf = [1, 2, 3, 4, 5] # send_buf = ones(5) # recv_buf = zeros(5) # tag = 42 @@ -28,17 +28,56 @@ nranks = MPI.Comm_size(comm) # println("Rank 0: Waiting...") # MPI.Wait(req_send) # println("Rank 0: Sent") +# elseif rank == 1 + # source = 0 + # req_recv = MPI.Irecv!(recv_buf, source, tag, comm) + # println("Rank 1: Waiting...") + # status = MPI.Wait(req_recv) + # println( "Rank 1: Received: $(recv_buf == send_buf)" ) + # # @test MPI.Get_source(status) == 0 + # # @test MPI.Get_tag(status) == 42 +# end + +# # Send / Irecv+wait +# send_buf = ones(5) +# recv_buf = zeros(5) +# tag = 43 +# if rank == 0 +# dest = 1 +# MPI.Send(send_buf, dest, tag, comm) # elseif rank == 1 # source = 0 + +# # MPI.Recv!(recv_buf, source, tag, comm) + # req_recv = MPI.Irecv!(recv_buf, source, tag, comm) -# println("Rank 1: Waiting...") # status = MPI.Wait(req_recv) -# println( "Rank 1: Received: $(recv_buf == send_buf)" ) -# # @test MPI.Get_source(status) == 0 -# # @test MPI.Get_tag(status) == 42 + +# println(recv_buf == send_buf) # end -# # -------------------------- +# # Isend+wait / Recv +# send_buf = ones(5) +# recv_buf = zeros(5) +# tag = 43 +# if rank == 0 +# dest = 1 + +# # MPI.Send(send_buf, dest, tag, comm) + +# req_send = MPI.Isend(send_buf, dest, tag, comm) +# MPI.Wait(req_send) +# elseif rank == 1 +# source = 0 + +# MPI.Recv!(recv_buf, source, tag, comm) + +# # req_recv = MPI.Irecv!(recv_buf, source, tag, comm) +# # status = MPI.Wait(req_recv) + +# println(recv_buf == send_buf) +# end +# # -------------------------- # send_buf = ConcreteRArray(ones(5)) @@ -81,34 +120,23 @@ nranks = MPI.Comm_size(comm) # # println("\nllvm:\n", @code_llvm bbb(recv_buf, src, tag, comm)) -# send_buf = ConcreteRArray(ones(5)) -# recv_buf = ConcreteRArray(zeros(5)) -# tag = 42 -# function aaa(send_buf, recv_buf, rank, tag, comm) -# if rank==0 -# dest = 1 -# req = MPI.Isend(send_buf, dest, tag, comm) -# # errcode = MPI.Wait(req) -# # return errcode -# return nothing -# elseif rank==1 -# src = 1 -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# errcode = MPI.Wait(req) -# return errcode, recv_buf -# end -# end -# @jit aaa(send_buf, recv_buf, rank, tag, comm) -# println("Rank $rank returned") - -# # rank==1 && sleep(10) -# # println("\n$rank: code_hlo optimize=false:\n", @code_hlo optimize=false aaa(send_buf, recv_buf, rank, tag, comm)) -# # println("\n$rank: code_hlo:\n", @code_hlo aaa(send_buf, recv_buf, rank, tag, comm)) -# # println("\n$rank: code_xla:\n", @code_xla aaa(send_buf, recv_buf, rank, tag, comm)) -# # bbb = @compile aaa(send_buf, recv_buf, rank, tag, comm) -# # println("\n$rank: lowered:\n", @code_lowered bbb(send_buf, recv_buf, rank, tag, comm)) -# # println("\n$rank: typed:\n", @code_typed bbb(send_buf, recv_buf, rank, tag, comm)) -# # println("\n$rank: llvm:\n", @code_llvm bbb(send_buf, recv_buf, rank, tag, comm)) - +send_buf = ConcreteRArray(ones(5)) +recv_buf = ConcreteRArray(zeros(5)) +tag = 42 +function aaa(send_buf, recv_buf, rank, tag, comm) + if rank==0 + dest = 1 + req = MPI.Isend(send_buf, dest, tag, comm) + MPI.Wait(req) + return nothing + elseif rank==1 + src = 0 + req = MPI.Irecv!(recv_buf, src, tag, comm) + MPI.Wait(req) + return nothing + end +end +@jit aaa(send_buf, recv_buf, rank, tag, comm) +println("Rank $rank returned, $(recv_buf==send_buf)") MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-send-recv2.jl 
b/roman-temp-debug/2025.09.mpi/test-send-recv2.jl index edf9da9df6..975996c208 100644 --- a/roman-temp-debug/2025.09.mpi/test-send-recv2.jl +++ b/roman-temp-debug/2025.09.mpi/test-send-recv2.jl @@ -20,39 +20,26 @@ function aaa(comm, rank, send_buf, recv_buf, tag) if rank == 0 dest = 1 # ccall(:jl_breakpoint, Cvoid, (Any,), dest) - return MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but have to return this otherwise julia optimizes this out + MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but have to return this otherwise julia optimizes this out + return nothing elseif rank == 1 src = 0 - # return MPI.Recv!(recv_buf, src, tag, comm) MPI.Recv!(recv_buf, src, tag, comm) - return recv_buf + return nothing end end -result = @jit aaa(comm, rank, send_buf, recv_buf, tag) - -if rank==0 - println("Rank $rank: $result") -elseif rank==1 - println("Rank $rank: $(result[2])") - println( recv_buf == send_buf ) -end - -# # rank==1 && sleep(5) -# # # println("\nRank $rank:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# # println("\nRank $rank:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) - -# if rank==0 -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\ncode_hlo:\n", @code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\ncode_xla:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) - -# bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) -# println("\nlowered:\n", @code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) -# println("\ntyped:\n", @code_typed bbb(comm, rank, send_buf, recv_buf, tag)) -# println("\nllvm:\n", @code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) -# end - +@jit aaa(comm, rank, send_buf, recv_buf, tag) +println("Rank $rank returned, $(recv_buf==send_buf)") + +# rank==1 && sleep(10) +# println("\n$rank: code_hlo optimize=false:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\n$rank: code_hlo:\n", @code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) +# println("\n$rank: code_xla:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) +# bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) +# println("\n$rank: lowered:\n", @code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) +# println("\n$rank: typed:\n", @code_typed bbb(comm, rank, send_buf, recv_buf, tag)) +# println("\n$rank: llvm:\n", @code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) # ---------------- From bc8f9c4fd9ef8b7f2322dc3eb2658a290ffa8fc6 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Thu, 18 Sep 2025 16:15:42 -0500 Subject: [PATCH 84/97] Delete debug files --- roman-temp-debug/2025.09.mpi/Project.toml | 4 - roman-temp-debug/2025.09.mpi/launcher.jl | 5 - roman-temp-debug/2025.09.mpi/runtests.sh | 22 -- roman-temp-debug/2025.09.mpi/sergio.jl | 29 --- roman-temp-debug/2025.09.mpi/setup.sh | 7 - .../2025.09.mpi/test-isend-irecv.jl | 142 ------------ .../2025.09.mpi/test-send-recv.jl | 210 ------------------ .../2025.09.mpi/test-send-recv2.jl | 87 -------- roman-temp-debug/Project.toml | 4 - roman-temp-debug/README.md | 1 - roman-temp-debug/bbb.jl | 14 -- roman-temp-debug/runtests.sh | 22 -- roman-temp-debug/setup.sh | 7 - 13 files changed, 554 deletions(-) delete mode 100644 roman-temp-debug/2025.09.mpi/Project.toml delete mode 100644 roman-temp-debug/2025.09.mpi/launcher.jl delete mode 100755 roman-temp-debug/2025.09.mpi/runtests.sh delete mode 100644 roman-temp-debug/2025.09.mpi/sergio.jl delete mode 100644 
roman-temp-debug/2025.09.mpi/setup.sh delete mode 100644 roman-temp-debug/2025.09.mpi/test-isend-irecv.jl delete mode 100644 roman-temp-debug/2025.09.mpi/test-send-recv.jl delete mode 100644 roman-temp-debug/2025.09.mpi/test-send-recv2.jl delete mode 100644 roman-temp-debug/Project.toml delete mode 100644 roman-temp-debug/README.md delete mode 100644 roman-temp-debug/bbb.jl delete mode 100755 roman-temp-debug/runtests.sh delete mode 100644 roman-temp-debug/setup.sh diff --git a/roman-temp-debug/2025.09.mpi/Project.toml b/roman-temp-debug/2025.09.mpi/Project.toml deleted file mode 100644 index 31b989ee28..0000000000 --- a/roman-temp-debug/2025.09.mpi/Project.toml +++ /dev/null @@ -1,4 +0,0 @@ -[deps] -MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/roman-temp-debug/2025.09.mpi/launcher.jl b/roman-temp-debug/2025.09.mpi/launcher.jl deleted file mode 100644 index c25f81d633..0000000000 --- a/roman-temp-debug/2025.09.mpi/launcher.jl +++ /dev/null @@ -1,5 +0,0 @@ -# launcher.jl -# usage e.g.: julia launcher.jl 1 test-isend-irecv.jl -using MPI -println(MPI.identify_implementation()) -run(`$(MPI.mpiexec()) -n $(ARGS[1]) julia --project $(ARGS[2])`) diff --git a/roman-temp-debug/2025.09.mpi/runtests.sh b/roman-temp-debug/2025.09.mpi/runtests.sh deleted file mode 100755 index c71f11f386..0000000000 --- a/roman-temp-debug/2025.09.mpi/runtests.sh +++ /dev/null @@ -1,22 +0,0 @@ -# ------------------- -# perlmutter -# ------------------- -salloc --nodes 1 --qos interactive --time 04:00:00 --constraint gpu --gpus 4 --account=nstaff - -# Flags from https://github.com/PRONTOLab/GB-25/blob/main/sharding/perlmutter_scaling_test.jl -export JULIA_CUDA_MEMORY_POOL=none -export JULIA_CUDA_USE_COMPAT=false - -# Flag from: https://github.com/PRONTOLab/GB-25/blob/main/sharding/common_submission_generator.jl -export XLA_REACTANT_GPU_MEM_FRACTION=0.9 - -srun -n 2 julia --project ./mpi.jl - -# Then added this flag to srun -srun -n 2 --gpus-per-task=1 julia --project ./mpi.jl - - -# ------------------- -# local laptop -# ------------------- -mpiexec -n 2 julia --project mpi.jl diff --git a/roman-temp-debug/2025.09.mpi/sergio.jl b/roman-temp-debug/2025.09.mpi/sergio.jl deleted file mode 100644 index a044f5284f..0000000000 --- a/roman-temp-debug/2025.09.mpi/sergio.jl +++ /dev/null @@ -1,29 +0,0 @@ -using Reactant -using MPI -using Libdl - -Reactant.set_default_backend("cpu") - -tag = 43 -comm = MPI.COMM_WORLD -source = 1 - -println("Here we go!") - -MPI.Init() - -if MPI.Comm_rank(MPI.COMM_WORLD) == 0 - buffer = Reactant.to_rarray(zeros(Int32, 8)) - println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] before - $buffer") - @jit MPI.Recv!(buffer, source, tag, comm) - println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] after - $buffer") - println(isapprox(buffer, ones(8))) -else - buffer = ones(Int32, 8) - destination = 0 - println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] sending - $buffer") - MPI.Send(buffer, destination, tag, comm) - println("[$(MPI.Comm_rank(MPI.COMM_WORLD))] sent!") -end - -MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/setup.sh b/roman-temp-debug/2025.09.mpi/setup.sh deleted file mode 100644 index 7dd7d6d268..0000000000 --- a/roman-temp-debug/2025.09.mpi/setup.sh +++ /dev/null @@ -1,7 +0,0 @@ -# how I set up a julia project in this directory -# These commands create Project.toml and Manifest.toml -julia ] -activate . 
-dev /global/homes/r/romanlee/Documents/codes/Reactant.jl -add MPI -add Test diff --git a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl b/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl deleted file mode 100644 index ee73d5d531..0000000000 --- a/roman-temp-debug/2025.09.mpi/test-isend-irecv.jl +++ /dev/null @@ -1,142 +0,0 @@ -using Test, MPI, Reactant, InteractiveUtils - -Reactant.set_default_backend("cpu") -# Reactant.set_default_backend("gpu") - -MPI.Init() - -comm = MPI.COMM_WORLD -rank = MPI.Comm_rank(comm) -nranks = MPI.Comm_size(comm) - - -# # -------------------------- -# # -------------------------- -# # test MPI.jl Isend / Irecv! -# # Skip test if not enough processes -# if nranks < 2 -# @error "Need at least 2 MPI processes for Isend/Irecv test" -# end - -# # send_buf = [1, 2, 3, 4, 5] -# send_buf = ones(5) -# recv_buf = zeros(5) -# tag = 42 -# if rank == 0 -# dest = 1 -# req_send = MPI.Isend(send_buf, dest, tag, comm) -# println("Rank 0: Waiting...") -# MPI.Wait(req_send) -# println("Rank 0: Sent") -# elseif rank == 1 - # source = 0 - # req_recv = MPI.Irecv!(recv_buf, source, tag, comm) - # println("Rank 1: Waiting...") - # status = MPI.Wait(req_recv) - # println( "Rank 1: Received: $(recv_buf == send_buf)" ) - # # @test MPI.Get_source(status) == 0 - # # @test MPI.Get_tag(status) == 42 -# end - -# # Send / Irecv+wait -# send_buf = ones(5) -# recv_buf = zeros(5) -# tag = 43 -# if rank == 0 -# dest = 1 -# MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# source = 0 - -# # MPI.Recv!(recv_buf, source, tag, comm) - -# req_recv = MPI.Irecv!(recv_buf, source, tag, comm) -# status = MPI.Wait(req_recv) - -# println(recv_buf == send_buf) -# end - -# # Isend+wait / Recv -# send_buf = ones(5) -# recv_buf = zeros(5) -# tag = 43 -# if rank == 0 -# dest = 1 - -# # MPI.Send(send_buf, dest, tag, comm) - -# req_send = MPI.Isend(send_buf, dest, tag, comm) -# MPI.Wait(req_send) -# elseif rank == 1 -# source = 0 - -# MPI.Recv!(recv_buf, source, tag, comm) - -# # req_recv = MPI.Irecv!(recv_buf, source, tag, comm) -# # status = MPI.Wait(req_recv) - -# println(recv_buf == send_buf) -# end -# # -------------------------- - - -# send_buf = ConcreteRArray(ones(5)) -# tag = 42 -# dest = 1 -# function aaa(send_buf, dest, tag, comm) -# req = MPI.Isend(send_buf, dest, tag, comm) -# errcode = MPI.Wait(req) -# return errcode -# end -# @jit aaa(send_buf, dest, tag, comm) - -# rank==1 && sleep(10) -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(send_buf, dest, tag, comm)) -# println("\ncode_hlo:\n", @code_hlo aaa(send_buf, dest, tag, comm)) -# println("\ncode_xla:\n", @code_xla aaa(send_buf, dest, tag, comm)) -# bbb = @compile aaa(send_buf, dest, tag, comm) -# println("\nlowered:\n", @code_lowered bbb(send_buf, dest, tag, comm)) -# # println("\ntyped:\n", @code_typed bbb(send_buf, dest, tag, comm)) -# # println("\nllvm:\n", @code_llvm bbb(send_buf, dest, tag, comm)) - - -# recv_buf = ConcreteRArray(zeros(5)) -# tag = 42 -# src = 1 -# function aaa(recv_buf, src, tag, comm) -# req = MPI.Irecv!(recv_buf, src, tag, comm) -# errcode = MPI.Wait(req) -# return errcode, recv_buf -# end -# @jit aaa(recv_buf, src, tag, comm) - -# rank==1 && sleep(10) -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false aaa(recv_buf, src, tag, comm)) -# println("\ncode_hlo:\n", @code_hlo aaa(recv_buf, src, tag, comm)) -# println("\ncode_xla:\n", @code_xla aaa(recv_buf, src, tag, comm)) -# bbb = @compile aaa(recv_buf, src, tag, comm) -# println("\nlowered:\n", @code_lowered bbb(recv_buf, 
src, tag, comm)) -# # println("\ntyped:\n", @code_typed bbb(recv_buf, src, tag, comm)) -# # println("\nllvm:\n", @code_llvm bbb(recv_buf, src, tag, comm)) - - -send_buf = ConcreteRArray(ones(5)) -recv_buf = ConcreteRArray(zeros(5)) -tag = 42 -function aaa(send_buf, recv_buf, rank, tag, comm) - if rank==0 - dest = 1 - req = MPI.Isend(send_buf, dest, tag, comm) - MPI.Wait(req) - return nothing - elseif rank==1 - src = 0 - req = MPI.Irecv!(recv_buf, src, tag, comm) - MPI.Wait(req) - return nothing - end -end -@jit aaa(send_buf, recv_buf, rank, tag, comm) -println("Rank $rank returned, $(recv_buf==send_buf)") - -MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-send-recv.jl b/roman-temp-debug/2025.09.mpi/test-send-recv.jl deleted file mode 100644 index bdbec587b2..0000000000 --- a/roman-temp-debug/2025.09.mpi/test-send-recv.jl +++ /dev/null @@ -1,210 +0,0 @@ -using Test, MPI, Reactant - -Reactant.set_default_backend("cpu") -# Reactant.set_default_backend("gpu") - -MPI.Init() - -# println(@code_hlo optimize=false MPI.Comm_rank(MPI.COMM_WORLD)) -# println(@code_hlo optimize=true MPI.Comm_rank(MPI.COMM_WORLD)) - -# pass on cpu -# fail on gpu: segfault when trying to return res in Ops.jl comm_rank -@testset "Comm_rank" begin - comm = MPI.COMM_WORLD - rank = MPI.Comm_rank(comm) - @test rank == @jit MPI.Comm_rank(comm) -end - -# pass on cpu -# fail on gpu: segfaulta upon trying to return res in Ops.jl comm_size -@testset "Comm_size" begin - comm = MPI.COMM_WORLD - nranks = MPI.Comm_size(comm) - @test nranks == @jit MPI.Comm_size(comm) -end - -@testset "Allreduce" begin - comm = MPI.COMM_WORLD - nranks = MPI.Comm_size(comm) - - # test good-ol MPI.jl allreduce - @test nranks == MPI.Allreduce(1, MPI.SUM, MPI.COMM_WORLD) - - # pass on cpu - # pass on gpu! - # test Reactant allreduce - @test nranks == @jit MPI.Allreduce(1, MPI.SUM, MPI.COMM_WORLD) -end - -@testset "Send, Recv!" begin - - comm = MPI.COMM_WORLD - rank = MPI.Comm_rank(comm) - nranks = MPI.Comm_size(comm) - - # "Need at least 2 MPI processes for send tests" - if nranks < 2 - @warn "need more than 2 mpi ranks, skipping" - return - end - - # test MPI.jl Send/Recv - @testset "MPI.jl Send / Recv!" begin - send_buf = fill(1) - tag = 43 - if rank == 0 - MPI.Send(send_buf, comm; dest=1, tag=tag) - @test true # Send completed - elseif rank == 1 - recv_buf = fill(12) - MPI.Recv!(recv_buf, comm; source=0, tag=tag) - @test recv_buf == send_buf - end - end - - # test MPI.jl Send / Reactant Recv - @testset "MPI.jl Send / Reactant Recv!" begin - send_buf = fill(1) - tag = 43 - if rank == 0 - MPI.Send(send_buf, comm; dest=1, tag=tag) - @test true - elseif rank == 1 - recv_buf = ConcreteRArray(fill(12)) - source = 0 - @jit MPI.Recv!(recv_buf, source, tag, comm) - @test recv_buf == send_buf - end - end - - # test Reactant Send / MPI.jl Recv - @testset "Reactant Send / MPI.jl Recv!" begin - send_buf = ConcreteRArray(fill(1)) - tag = 43 - if rank == 0 - dest = 1 - @jit MPI.Send(send_buf, dest, tag, comm) - @test true - elseif rank == 1 - recv_buf = fill(12) - MPI.Recv!(recv_buf, comm; source=0, tag=tag) - @test recv_buf == send_buf - end - end - - # test Reactant Send/Recv - @testset "Reactant Send / Recv!" 
begin - send_buf = ConcreteRArray(fill(1)) - tag = 43 - if rank == 0 - # Send: pass on cpu, pass on gpu - dest = 1 - @jit MPI.Send(send_buf, dest, tag, comm) - @test true # Send completed - elseif rank == 1 - # hang on cpu - # segfault on gpu upon trying to reference res - recv_buf = ConcreteRArray(fill(12)) - src = 0 - @jit MPI.Recv!(recv_buf, src, tag, comm) - @test recv_buf == send_buf - end - end -end - -# ---------- -# debug -# ---------- -# comm = MPI.COMM_WORLD -# rank = MPI.Comm_rank(comm) -# nranks = MPI.Comm_size(comm) - -# send_buf = ConcreteRArray(fill(1)) -# tag = 43 -# dest = 1 -# # @jit dbSend(send_buf, dest, tag, comm) -# # @jit MPI.Senddd(send_buf, dest, tag, comm) -# # @jit Senddd(send_buf, dest, tag, comm) -# @jit func_foo() - -# if nranks < 2 -# @error "Need at least 2 MPI processes for send tests. Skipping." -# end - -# # test Reactant Send/Recv -# send_buf = ConcreteRArray(fill(1)) -# tag = 43 -# if rank == 0 -# # Send: pass on cpu, pass on gpu -# # dest = 1 -# dest = 1 -# # @jit MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# # # hang on cpu -# # # segfault on gpu upon trying to reference res -# # recv_buf = ConcreteRArray(fill(12)) -# # src = 0 -# # @jit MPI.Recv!(recv_buf, src, tag, comm) -# end - - - - -# # # test Reactant Send/Recv -# # send_buf = ConcreteRArray(fill(1)) -# # if rank == 0 -# # # Send: pass on cpu, pass on gpu -# # @jit MPI.Send(send_buf, 1, 0, comm) - -# # dest = 12 -# # tag = 33 -# # println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) - -# # # @test true # Send completed -# # elseif rank == 1 -# # # # hang on cpu -# # # # segfault on gpu upon trying to reference res -# # # recv_buf = ConcreteRArray(fill(12)) -# # # # @jit MPI.Recv!(recv_buf, 0, 0, comm) -# # # source = 12 -# # # tag = 35 -# # # println(@code_hlo optimize=false MPI.Recv!(recv_buf, source, tag, comm)) -# # # # @test recv_buf == send_buf - -# # # # println(@code_hlo MPI.Recv!(recv_buf, 0, 0, comm)) -# # end - -# send_buf = ConcreteRArray(fill(1)) -# tag = 43 -# if rank == 0 -# dest = 3333 - -# println("@code_hlo optimize=false:") -# println(@code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) -# println("") - -# # println("@code_hlo optimize=:before_jit:") -# # println(@code_hlo optimize=:before_jit MPI.Send(send_buf, dest, tag, comm)) -# # println("") - -# # println("@jit MPI.Send:") -# # @jit MPI.Send(send_buf, dest, tag, comm) - -# elseif rank == 1 -# # recv_buf = ConcreteRArray(fill(12)) -# # source = 0 - -# # println("code hlo:") -# # println(@code_hlo optimize=false MPI.Recv!(recv_buf, source, tag, comm)) -# # println("") - -# # println("@jit MPI.Recv!:") -# # @jit MPI.Recv!(recv_buf, source, tag, comm) - -# # # # println("after ", recv_buf==send_buf) -# end - - - -MPI.Finalize() diff --git a/roman-temp-debug/2025.09.mpi/test-send-recv2.jl b/roman-temp-debug/2025.09.mpi/test-send-recv2.jl deleted file mode 100644 index 975996c208..0000000000 --- a/roman-temp-debug/2025.09.mpi/test-send-recv2.jl +++ /dev/null @@ -1,87 +0,0 @@ -using Test, MPI, Reactant, InteractiveUtils - -Reactant.set_default_backend("cpu") -# Reactant.set_default_backend("gpu") - -MPI.Init() - -comm = MPI.COMM_WORLD -rank = MPI.Comm_rank(comm) -nranks = MPI.Comm_size(comm) - - -# ---------------- -# Send/Recv! 
in one func -# ---------------- -send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -tag = 43 -function aaa(comm, rank, send_buf, recv_buf, tag) - if rank == 0 - dest = 1 - # ccall(:jl_breakpoint, Cvoid, (Any,), dest) - MPI.Send(send_buf, dest, tag, comm) # kinda hacky, but have to return this otherwise julia optimizes this out - return nothing - elseif rank == 1 - src = 0 - MPI.Recv!(recv_buf, src, tag, comm) - return nothing - end -end - -@jit aaa(comm, rank, send_buf, recv_buf, tag) -println("Rank $rank returned, $(recv_buf==send_buf)") - -# rank==1 && sleep(10) -# println("\n$rank: code_hlo optimize=false:\n", @code_hlo optimize=false aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\n$rank: code_hlo:\n", @code_hlo aaa(comm, rank, send_buf, recv_buf, tag)) -# println("\n$rank: code_xla:\n", @code_xla aaa(comm, rank, send_buf, recv_buf, tag)) -# bbb = @compile aaa(comm, rank, send_buf, recv_buf, tag) -# println("\n$rank: lowered:\n", @code_lowered bbb(comm, rank, send_buf, recv_buf, tag)) -# println("\n$rank: typed:\n", @code_typed bbb(comm, rank, send_buf, recv_buf, tag)) -# println("\n$rank: llvm:\n", @code_llvm bbb(comm, rank, send_buf, recv_buf, tag)) - - -# ---------------- -# Send/Recv! compiled separately -# ---------------- -# # test: works -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 1 -# @jit MPI.Send(send_buf, dest, tag, comm) -# elseif rank == 1 -# src = 0 -# @jit MPI.Recv!(recv_buf, src, tag, comm) -# println( recv_buf == send_buf ) -# end - - -# debug -# send_buf = ConcreteRArray([1, 2, 3, 4, 5]) -# recv_buf = ConcreteRArray([-1, -2, -3, -4, -5]) -# tag = 43 -# if rank == 0 -# dest = 1 - -# # @jit MPI.Send(send_buf, dest, tag, comm) - -# println("\ncode_hlo optimize=false:\n", @code_hlo optimize=false MPI.Send(send_buf, dest, tag, comm)) -# println("\ncode_hlo:\n", @code_hlo MPI.Send(send_buf, dest, tag, comm)) -# println("\ncode_xla:\n", @code_xla MPI.Send(send_buf, dest, tag, comm)) - -# sss = @compile MPI.Send(send_buf, dest, tag, comm) -# println("\nlowered:\n", @code_lowered sss(send_buf, dest, tag, comm)) -# println("\ntyped:\n", @code_typed sss(send_buf, dest, tag, comm)) -# println("\nllvm:\n", @code_llvm sss(send_buf, dest, tag, comm)) - -# # elseif rank == 1 -# # src = 0 -# # @jit MPI.Recv!(recv_buf, src, tag, comm) -# # println( recv_buf == send_buf ) -# end - - -MPI.Finalize() diff --git a/roman-temp-debug/Project.toml b/roman-temp-debug/Project.toml deleted file mode 100644 index 31b989ee28..0000000000 --- a/roman-temp-debug/Project.toml +++ /dev/null @@ -1,4 +0,0 @@ -[deps] -MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/roman-temp-debug/README.md b/roman-temp-debug/README.md deleted file mode 100644 index 04c00d9276..0000000000 --- a/roman-temp-debug/README.md +++ /dev/null @@ -1 +0,0 @@ -I will delete this before we submit the PR diff --git a/roman-temp-debug/bbb.jl b/roman-temp-debug/bbb.jl deleted file mode 100644 index 2e640a283b..0000000000 --- a/roman-temp-debug/bbb.jl +++ /dev/null @@ -1,14 +0,0 @@ -g(x::Float64, y) = 2x + y -display(g) - -# g(x, y::Float64) = x + 2y -# display(g) - -# println(g(2.0, 3)) - -# println(g(2, 3.0)) - -# println(g(2.0, 3.0)) - -g(x::Number, y) = 2x + y -println(g(2.0, 3)) diff --git a/roman-temp-debug/runtests.sh b/roman-temp-debug/runtests.sh deleted file mode 100755 index 
c71f11f386..0000000000 --- a/roman-temp-debug/runtests.sh +++ /dev/null @@ -1,22 +0,0 @@ -# ------------------- -# perlmutter -# ------------------- -salloc --nodes 1 --qos interactive --time 04:00:00 --constraint gpu --gpus 4 --account=nstaff - -# Flags from https://github.com/PRONTOLab/GB-25/blob/main/sharding/perlmutter_scaling_test.jl -export JULIA_CUDA_MEMORY_POOL=none -export JULIA_CUDA_USE_COMPAT=false - -# Flag from: https://github.com/PRONTOLab/GB-25/blob/main/sharding/common_submission_generator.jl -export XLA_REACTANT_GPU_MEM_FRACTION=0.9 - -srun -n 2 julia --project ./mpi.jl - -# Then added this flag to srun -srun -n 2 --gpus-per-task=1 julia --project ./mpi.jl - - -# ------------------- -# local laptop -# ------------------- -mpiexec -n 2 julia --project mpi.jl diff --git a/roman-temp-debug/setup.sh b/roman-temp-debug/setup.sh deleted file mode 100644 index 7dd7d6d268..0000000000 --- a/roman-temp-debug/setup.sh +++ /dev/null @@ -1,7 +0,0 @@ -# how I set up a julia project in this directory -# These commands create Project.toml and Manifest.toml -julia ] -activate . -dev /global/homes/r/romanlee/Documents/codes/Reactant.jl -add MPI -add Test From 4899d6d8708e9dc6d7c87e21c250c38b309709c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:28:28 +0100 Subject: [PATCH 85/97] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantMPIExt/Ops.jl | 15 +++------ ext/ReactantMPIExt/Overrides.jl | 46 ++++++++-------------------- ext/ReactantMPIExt/ReactantMPIExt.jl | 6 +--- test/integration/mpi.jl | 9 +++--- 4 files changed, 22 insertions(+), 54 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index f448918eb3..6c18ba5400 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -338,9 +338,7 @@ function send( count = Reactant.Ops.constant(Int32(length(buf))) enzymexla.jit_call( - IR.Value[ - buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data - ]; + IR.Value[buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data]; fn=sym_attr, result_0=IR.Type[], output_operand_aliases=IR.Attribute(IR.Attribute[]), @@ -459,12 +457,7 @@ function recv!( ]) ret = enzymexla.jit_call( - IR.Value[ - recvbuf.mlir_data, - count.mlir_data, - src.mlir_data, - tag.mlir_data, - ]; + IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data]; fn=sym_attr, result_0=[mlir_type(recvbuf)], output_operand_aliases, @@ -550,7 +543,9 @@ function wait( sym_attr = IR.FlatSymbolRefAttribute(sym_name) # likely isend/irecv will have injected MPI_COMM_WORLD already - IR.tryinject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.tryinject!( + "MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr" + ) IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 5b08d81f28..ab3c6baab0 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -32,12 +32,7 @@ function MPI.Wait(req::TracedRequest) end # TODO use `make_tracer` to linearize arbitrary types? 
check out `MPI.Buffer` -function MPI.Send( - buf::TracedRArray, - dest::Integer, - tag::Integer, - comm::MPI.Comm -) +function MPI.Send(buf::TracedRArray, dest::Integer, tag::Integer, comm::MPI.Comm) tag = Reactant.Ops.constant(tag) dest = Reactant.Ops.constant(dest) return MPI.Send(buf, dest, tag, comm) @@ -45,10 +40,7 @@ end # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` function MPI.Send( - buf::TracedRArray, - dest::TracedRNumber, - tag::TracedRNumber, - comm::MPI.Comm + buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm ) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" return Ops.send(buf, tag, dest) @@ -56,11 +48,11 @@ end # TODO should we error if other `AbstractRequest` types are passed in? function MPI.Isend( - buf::TracedRArray, - dest::Integer, - tag::Integer, - comm::MPI.Comm, - request::TracedRequest=TracedRequest((), nothing) + buf::TracedRArray, + dest::Integer, + tag::Integer, + comm::MPI.Comm, + request::TracedRequest=TracedRequest((), nothing), ) dest = Reactant.Ops.constant(dest) tag = Reactant.Ops.constant(tag) @@ -72,10 +64,7 @@ end # TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` function MPI.Isend( - buf::TracedRArray, - dest::TracedRNumber, - tag::TracedRNumber, - comm::MPI.Comm, + buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm ) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" @@ -103,12 +92,7 @@ end # return Ops.isend(buf, tag, dest) # end -function MPI.Recv!( - buf::TracedRArray, - source::Integer, - tag::Integer, - comm::MPI.Comm -) +function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm) tag = Reactant.Ops.constant(tag) source = Reactant.Ops.constant(source) return MPI.Recv!(buf, source, tag, comm) @@ -133,10 +117,7 @@ end # TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` function MPI.Recv!( - buf::TracedRArray, - source::TracedRNumber, - tag::TracedRNumber, - comm::MPI.Comm + buf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm ) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" return Ops.recv!(buf, tag, source) @@ -147,7 +128,7 @@ function MPI.Irecv!( source::Integer, tag::Integer, comm::MPI.Comm, - request::TracedRequest=TracedRequest((), nothing) + request::TracedRequest=TracedRequest((), nothing), ) source = Reactant.Ops.constant(source) tag = Reactant.Ops.constant(tag) @@ -159,10 +140,7 @@ end # TODO use `make_tracer` to delinearize arbitrary types? 
check out `MPI.Buffer` function MPI.Irecv!( - buf::TracedRArray, - source::TracedRNumber, - tag::TracedRNumber, - comm::MPI.Comm + buf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm ) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 49cfdf9d46..0b0d65b6cb 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -222,14 +222,11 @@ function __init__() end end - mutable struct TracedRequest <: MPI.AbstractRequest paths::Tuple mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} - function TracedRequest( - paths::Tuple, mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} - ) + function TracedRequest(paths::Tuple, mlir_data::Union{Nothing,Reactant.MLIR.IR.Value}) if !isnothing(mlir_data) @assert size(Reactant.MLIR.IR.type(mlir_data)) == () end @@ -328,7 +325,6 @@ Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data # return result_cache[tocopy] # end - include("Ops.jl") include("Overrides.jl") diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index 3f70ab93d3..4de98c8b7e 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -90,9 +90,8 @@ end end end @jit sendrecv!(comm, rank, send_buf, recv_buf, tag) - rank==1 && @test recv_buf == send_buf + rank == 1 && @test recv_buf == send_buf end - end @testset "Isend / Irecv! / Wait" begin @@ -105,12 +104,12 @@ end recv_buf = ConcreteRArray(zeros(5)) tag = 42 function isendirecvwait(send_buf, recv_buf, rank, tag, comm) - if rank==0 + if rank == 0 dest = 1 req = MPI.Isend(send_buf, dest, tag, comm) MPI.Wait(req) return nothing - elseif rank==1 + elseif rank == 1 src = 0 req = MPI.Irecv!(recv_buf, src, tag, comm) MPI.Wait(req) @@ -118,7 +117,7 @@ end end end @jit isendirecvwait(send_buf, recv_buf, rank, tag, comm) - rank==1 && @test recv_buf == send_buf + rank == 1 && @test recv_buf == send_buf end MPI.Finalize() From a4e159f5514fa67536c00788728d1027fb340b2c Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Mon, 22 Sep 2025 15:51:09 -0500 Subject: [PATCH 86/97] Use the already dlopened libmpi --- ext/ReactantMPIExt/ReactantMPIExt.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 0b0d65b6cb..c48a28dc06 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -40,8 +40,7 @@ function __init__() # # when you fix it, remember to possibly make TracedType const again # Reactant.TracedType = Union{Reactant.TracedRArray,Reactant.TracedRNumber,Reactant.MissingTracedValue,TracedRequest} - # TODO maybe it's more efficient if we use `RTLD_NOW` instead of `RTLD_LAZY`? 
- libmpi_handle = Libdl.dlopen(MPI.API.libmpi, RTLD_LAZY | RTLD_GLOBAL) + libmpi_handle = MPI.API.libmpi_handle # register MPI routines for name in [ From c6def5966c2af4afc77b46353b1b327f1af89a48 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Mon, 22 Sep 2025 18:10:29 -0500 Subject: [PATCH 87/97] Remove function convert_julia_type_to_mpi_datatype(), use MPI.Dataype() instead --- ext/ReactantMPIExt/Ops.jl | 34 +++++----------------------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 6c18ba5400..9c2f114d2d 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -274,30 +274,6 @@ function inject_mpi_datatype!(datatype) end end -function convert_julia_type_to_mpi_datatype(T::Type) - if T === Bool - MPI.C_BOOL - elseif T === Int8 - MPI.INT8_T - elseif T === Int16 - MPI.INT16_T - elseif T === Int32 - MPI.INT32_T - elseif T === Int64 - MPI.INT64_T - elseif T === Float32 - MPI.FLOAT - elseif T === Float64 - MPI.DOUBLE - elseif T === ComplexF32 - MPI.C_FLOAT_COMPLEX - elseif T === ComplexF64 - MPI.C_DOUBLE_COMPLEX - else - throw(ArgumentError("Unknown conversion from $T to a MPI_Datatype")) - end -end - function send( buf::TracedRArray, tag::TracedRNumber, @@ -305,7 +281,7 @@ function send( location=mlir_stacktrace("mpi.send", @__FILE__, @__LINE__), ) T = Reactant.unwrapped_eltype(buf) - mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype = MPI.Datatype(T) mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) sym_name = "enzymexla_wrapper_MPI_Send_$(mpi_datatype_name)" @@ -356,7 +332,7 @@ function isend( location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), ) T = Reactant.unwrapped_eltype(buf) - mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype = MPI.Datatype(T) mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) sym_name = "enzymexla_wrapper_MPI_Isend_$(mpi_datatype_name)" @@ -416,7 +392,7 @@ function recv!( location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__), ) T = Reactant.unwrapped_eltype(recvbuf) - mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype = MPI.Datatype(T) mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)" @@ -477,7 +453,7 @@ function irecv!( location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), ) T = Reactant.unwrapped_eltype(buf) - mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype = MPI.Datatype(T) mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) sym_name = "enzymexla_wrapper_MPI_Irecv_$(mpi_datatype_name)" @@ -626,7 +602,7 @@ function allreduce!( op_name = inject_mpi_op!(op) T = Reactant.unwrapped_eltype(sendbuf) - mpi_datatype = convert_julia_type_to_mpi_datatype(T) + mpi_datatype = MPI.Datatype(T) mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") From 53112ca5e85732957a2453225d1f38b4f3b02524 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Mon, 22 Sep 2025 18:29:50 -0500 Subject: [PATCH 88/97] Cleanup/remove some comments --- ext/ReactantMPIExt/ReactantMPIExt.jl | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index c48a28dc06..035c5e0915 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -36,10 +36,6 @@ function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) end 
function __init__() - # # TODO improve this, temporary hack - # # when you fix it, remember to possibly make TracedType const again - # Reactant.TracedType = Union{Reactant.TracedRArray,Reactant.TracedRNumber,Reactant.MissingTracedValue,TracedRequest} - libmpi_handle = MPI.API.libmpi_handle # register MPI routines @@ -237,21 +233,20 @@ function Base.show(io::IOty, X::TracedRequest) where {IOty<:Union{IO,IOContext}} return print(io, "TracedRequest(", X.paths, ")") end -Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data +# # NOTE: Commenting out the below on the assumption that a Request will never cross the compile boundary +# # If we ever want to return a request, the below could serve as a starting point +# Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data # Reactant.TracedUtils.set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) -# # May need these later, but for now we assume that no request needs to pass the compile boundary # Reactant.TracedUtils.get_paths(x::TracedRequest) = x.paths # Reactant.TracedUtils.set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) # -# # TODO not sure how to implement this for TracedRequest -# # probably just want to hardcode the types and dims? # function Reactant.Ops.mlir_type(x::TracedRequest)::MLIR.IR.Type # # return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x))) # return MLIR.IR.TensorType(collect(Int, ()), MLIR.IR.Type(Int64)) # end # -# TODO if want to use this, need to somehow add TracedRequest to TracedTypes, which is currently const +# TODO for this to work properly in finalize_mlir_fn(), need to add TracedRequest to TracedTypes, currently const # Base.@nospecializeinfer function Reactant.make_tracer( # seen, # @nospecialize(prev::TracedRequest), From d77a4d9d2ae8fcd86f1adc4e7ac2b0dec413a16e Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Mon, 22 Sep 2025 18:43:31 -0500 Subject: [PATCH 89/97] Cleanup some comments --- ext/ReactantMPIExt/Ops.jl | 2 -- ext/ReactantMPIExt/Overrides.jl | 23 +---------------------- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl index 9c2f114d2d..40c2310956 100644 --- a/ext/ReactantMPIExt/Ops.jl +++ b/ext/ReactantMPIExt/Ops.jl @@ -287,8 +287,6 @@ function send( sym_name = "enzymexla_wrapper_MPI_Send_$(mpi_datatype_name)" sym_attr = IR.FlatSymbolRefAttribute(sym_name) - # TODO emit constant for size and datatype, and pass as args - IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") IR.inject!( "MPI_Send", diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index ab3c6baab0..2ad2f854f3 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -71,34 +71,13 @@ function MPI.Isend( return Ops.isend(buf, tag, dest) end -# TODO possible to use this signature? As is, ambiguous with the ones defined by MPI.jl -# # TODO use `make_tracer` to linearize arbitrary types? 
check out `MPI.Buffer` -# function MPI.Isend( -# buf::TracedRArray, -# dest::Union{T,TracedRNumber{T}}, -# tag::Union{T,TracedRNumber{T}}, -# comm::MPI.Comm, -# ) where {T<:Integer} -# @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" - -# tag = if !(tag isa TracedRNumber) -# Reactant.Ops.constant(tag) -# end - -# dest = if !(dest isa TracedRNumber) -# Reactant.Ops.constant(dest) -# end - -# return Ops.isend(buf, tag, dest) -# end - function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm) tag = Reactant.Ops.constant(tag) source = Reactant.Ops.constant(source) return MPI.Recv!(buf, source, tag, comm) end -# TODO Do we need these? Comment out at least until everything is working +# TODO Do we need these? # function MPI.Recv!( # buf::TracedRArray, # source::Integer, From 508ed6ac48b0a9a28936ce2be1b52f034fd0cee3 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Tue, 23 Sep 2025 13:46:09 -0500 Subject: [PATCH 90/97] set default backend cpu for mpi tests, then set it back --- test/integration/mpi.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index 4de98c8b7e..294d1ec974 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -1,7 +1,7 @@ using Test, MPI, Reactant -# # MPI only works on cpu currently --- is this the right way/place to enforce that? -# Reactant.set_default_backend("cpu") +client = Reactant.XLA.default_backend() +Reactant.set_default_backend("cpu") MPI.Init() @@ -121,3 +121,5 @@ end end MPI.Finalize() + +Reactant.set_default_backend(client) From 8b4631e682f7c7a10b41f2220a5c6dcce2b88afa Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Tue, 23 Sep 2025 14:29:04 -0500 Subject: [PATCH 91/97] Add test for MPI Barrier --- test/integration/mpi.jl | 74 ++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index 294d1ec974..f36cfdc851 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -24,40 +24,52 @@ end @test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD) end +@testset "Barrier" begin + @testset "Single Barrier" begin + comm = MPI.COMM_WORLD + ret = @jit MPI.Barrier(comm) + @test ret === nothing + end + + @testset "Consecutive Barriers" begin + # Test multiple consecutive barriers + comm = MPI.COMM_WORLD + for i in 1:3 + @test_nowarn @jit MPI.Barrier(comm) + end + end +end + @testset "Send / Recv!" begin comm = MPI.COMM_WORLD rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) - # test MPI.jl Send / Reactant Recv - # useful to isolate Reactant issues - @testset "MPI.jl Send / Reactant Recv!" begin - send_buf = ones(5) - tag = 43 - if rank == 0 - MPI.Send(send_buf, comm; dest=1, tag=tag) - elseif rank == 1 - recv_buf = ConcreteRArray(zeros(5)) - source = 0 - @jit MPI.Recv!(recv_buf, source, tag, comm) - @test recv_buf == send_buf - end - end - - # test Reactant Send / MPI.jl Recv - # useful to isolate Reactant issues - @testset "Reactant Send / MPI.jl Recv!" begin - send_buf = ConcreteRArray(ones(5)) - tag = 43 - if rank == 0 - dest = 1 - @jit MPI.Send(send_buf, dest, tag, comm) - elseif rank == 1 - recv_buf = zeros(5) - MPI.Recv!(recv_buf, comm; source=0, tag=tag) - @test recv_buf == send_buf - end - end + # # useful for isolating whether Reactant Send or Recv! is the issue + # @testset "MPI.jl Send / Reactant Recv!" 
begin + # send_buf = ones(5) + # tag = 43 + # if rank == 0 + # MPI.Send(send_buf, comm; dest=1, tag=tag) + # elseif rank == 1 + # recv_buf = ConcreteRArray(zeros(5)) + # source = 0 + # @jit MPI.Recv!(recv_buf, source, tag, comm) + # @test recv_buf == send_buf + # end + # end + # @testset "Reactant Send / MPI.jl Recv!" begin + # send_buf = ConcreteRArray(ones(5)) + # tag = 43 + # if rank == 0 + # dest = 1 + # @jit MPI.Send(send_buf, dest, tag, comm) + # elseif rank == 1 + # recv_buf = zeros(5) + # MPI.Recv!(recv_buf, comm; source=0, tag=tag) + # @test recv_buf == send_buf + # end + # end # test Reactant Send/Recv @testset "Reactant Send / Recv! - compiled separately" begin @@ -99,7 +111,9 @@ end rank = MPI.Comm_rank(comm) nranks = MPI.Comm_size(comm) - # note: currently don't allow a request to cross the compile boundary + # NOTE: currently don't allow a request to cross the compile boundary + # debugging tip: if this fails, can use pair Send with Irecv! + Wait, or Recv! with + # Isend + Wait to isolate the issue send_buf = ConcreteRArray(ones(5)) recv_buf = ConcreteRArray(zeros(5)) tag = 42 From 36b0d26af646d042b13f520caa79663f74a1bebc Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Tue, 23 Sep 2025 14:32:22 -0500 Subject: [PATCH 92/97] Remove comment --- test/integration/mpi.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl index f36cfdc851..b7e6859dbe 100644 --- a/test/integration/mpi.jl +++ b/test/integration/mpi.jl @@ -32,7 +32,6 @@ end end @testset "Consecutive Barriers" begin - # Test multiple consecutive barriers comm = MPI.COMM_WORLD for i in 1:3 @test_nowarn @jit MPI.Barrier(comm) From 9a312b743f01bfa8d5bbc3aed708f7d1e4e1fdf0 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Tue, 23 Sep 2025 15:41:32 -0500 Subject: [PATCH 93/97] Comment out MPI.Init and MPI.Finalize in Overrides, not implemented yet --- ext/ReactantMPIExt/Overrides.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl index 2ad2f854f3..73638c8e85 100644 --- a/ext/ReactantMPIExt/Overrides.jl +++ b/ext/ReactantMPIExt/Overrides.jl @@ -1,15 +1,15 @@ using Reactant: @reactant_overlay, TracedRArray, TracedRNumber -@reactant_overlay @noinline function MPI.Init(; kwargs...) - if !isempty(kwargs) - @warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs... - end - return Ops.init() -end +# @reactant_overlay @noinline function MPI.Init(; kwargs...) +# if !isempty(kwargs) +# @warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs... +# end +# return Ops.init() +# end -@reactant_overlay @noinline function MPI.Finalize(; kwargs...) - return Ops.finalize() -end +# @reactant_overlay @noinline function MPI.Finalize(; kwargs...) 
+# return Ops.finalize() +# end @reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm) @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" From 6feb05d6a595d5ae21c172ccace8ea9c4ac4f690 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Wed, 24 Sep 2025 15:02:50 -0500 Subject: [PATCH 94/97] Don't register MPI_ERR_RMA_RANGE so CI green since not defined for windows --- ext/ReactantMPIExt/ReactantMPIExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 035c5e0915..7ef41b0f3a 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -197,7 +197,7 @@ function __init__() # :MPI_T_ERR_PVAR_NO_STARTSTOP, # :MPI_T_ERR_PVAR_NO_WRITE, # :MPI_T_ERR_PVAR_NO_ATOMIC, - :MPI_ERR_RMA_RANGE, + # :MPI_ERR_RMA_RANGE, # not defined for windows mpi :MPI_ERR_RMA_ATTACH, :MPI_ERR_RMA_FLAVOR, :MPI_ERR_RMA_SHARED, From 1d924908471bdcbd064cc352e769126401f69006 Mon Sep 17 00:00:00 2001 From: Roman Lee <31547765+romanlee@users.noreply.github.com> Date: Wed, 24 Sep 2025 13:14:52 -0700 Subject: [PATCH 95/97] Update ext/ReactantMPIExt/ReactantMPIExt.jl Co-authored-by: Avik Pal --- ext/ReactantMPIExt/ReactantMPIExt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 7ef41b0f3a..3f9c555e7f 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -208,6 +208,7 @@ function __init__() # :MPI_ERR_PROC_FAILED_PENDING, # :MPI_ERR_REVOKED, ] + !hasproperty(MPI.API, name) && continue value = getproperty(MPI.API, name) if value isa Base.RefValue value = value[] From 17f81448bdbee18a5fb23a9fc5adf62e7a0f2f40 Mon Sep 17 00:00:00 2001 From: Roman Lee Date: Wed, 24 Sep 2025 15:16:14 -0500 Subject: [PATCH 96/97] Revert "Don't register MPI_ERR_RMA_RANGE so CI green since not defined for" This reverts commit 6feb05d6a595d5ae21c172ccace8ea9c4ac4f690. --- ext/ReactantMPIExt/ReactantMPIExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 3f9c555e7f..91f2794b15 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -197,7 +197,7 @@ function __init__() # :MPI_T_ERR_PVAR_NO_STARTSTOP, # :MPI_T_ERR_PVAR_NO_WRITE, # :MPI_T_ERR_PVAR_NO_ATOMIC, - # :MPI_ERR_RMA_RANGE, # not defined for windows mpi + :MPI_ERR_RMA_RANGE, :MPI_ERR_RMA_ATTACH, :MPI_ERR_RMA_FLAVOR, :MPI_ERR_RMA_SHARED, From 8a829ef8462798d0fb5cbb2101294ad75ac1f5ba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 17:03:11 -0400 Subject: [PATCH 97/97] Update ext/ReactantMPIExt/ReactantMPIExt.jl --- ext/ReactantMPIExt/ReactantMPIExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl index 91f2794b15..9d4dd7a651 100644 --- a/ext/ReactantMPIExt/ReactantMPIExt.jl +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -208,7 +208,7 @@ function __init__() # :MPI_ERR_PROC_FAILED_PENDING, # :MPI_ERR_REVOKED, ] - !hasproperty(MPI.API, name) && continue + !isdefined(MPI.API, name) && continue value = getproperty(MPI.API, name) if value isa Base.RefValue value = value[]
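Taken together, these patches leave test/integration/mpi.jl as the reference for how the traced MPI overlays are exercised. Below is a condensed, self-contained sketch of that usage on two ranks, mirroring the test file and the mpiexec invocation from the removed runtests.sh; the script name, buffer sizes, and tag value are illustrative only.

    # Run under two ranks, e.g.: mpiexec -n 2 julia --project <this_script>.jl
    using Test, MPI, Reactant

    Reactant.set_default_backend("cpu")  # the MPI path is exercised on the CPU backend in the tests
    MPI.Init()                           # plain MPI.jl init; the traced Init/Finalize overlays are commented out

    comm = MPI.COMM_WORLD
    rank = MPI.Comm_rank(comm)

    @test rank == @jit MPI.Comm_rank(comm)  # traced Comm_rank under @jit
    @jit MPI.Barrier(comm)                  # traced Barrier, as covered by the Barrier testset

    # Point-to-point: integer dest/source/tag are promoted to constants by the overrides
    send_buf = ConcreteRArray(ones(5))
    recv_buf = ConcreteRArray(zeros(5))
    tag = 43
    if rank == 0
        @jit MPI.Send(send_buf, 1, tag, comm)
    elseif rank == 1
        @jit MPI.Recv!(recv_buf, 0, tag, comm)
        @test recv_buf == send_buf
    end

    MPI.Finalize()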