diff --git a/ext/ReactantMPIExt.jl b/ext/ReactantMPIExt.jl deleted file mode 100644 index 5ede919c0d..0000000000 --- a/ext/ReactantMPIExt.jl +++ /dev/null @@ -1,36 +0,0 @@ -module ReactantMPIExt - -using Reactant: Reactant, Distributed -using MPI: MPI - -# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py -Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized() - -function Distributed.get_coordinator_address( - ::Distributed.MPIEnvDetector, timeout_in_seconds::Integer -) - if MPI.Comm_rank(MPI.COMM_WORLD) == 0 - hostname = gethostname() - port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1) - hostname = "$(hostname):$(port_id)" - else - hostname = nothing - end - - return MPI.bcast(hostname, MPI.COMM_WORLD; root=0) -end - -function Distributed.get_process_count(::Distributed.MPIEnvDetector) - return Int(MPI.Comm_size(MPI.COMM_WORLD)) -end - -function Distributed.get_process_id(::Distributed.MPIEnvDetector) - return Int(MPI.Comm_rank(MPI.COMM_WORLD)) -end - -function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) - new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0) - return Int(MPI.Comm_rank(new_comm)) -end - -end diff --git a/ext/ReactantMPIExt/Ops.jl b/ext/ReactantMPIExt/Ops.jl new file mode 100644 index 0000000000..40c2310956 --- /dev/null +++ b/ext/ReactantMPIExt/Ops.jl @@ -0,0 +1,654 @@ +module Ops +using Reactant: Reactant, TracedRArray, TracedRNumber +using Reactant: MLIR +using Reactant.MLIR: IR +using Reactant.MLIR.IR: @mlir_str +using Reactant.MLIR.Dialects: mpi, func, llvm, enzymexla +using Reactant.Ops: mlir_stacktrace, mlir_type +using ..ReactantMPIExt: TracedRequest +using MPI: MPI + +# TODO we might need to have a `TracedComm` for communicators created during the compiled function + +# function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__)) +# return mpi.init(; location) +# end + +# function finalize(; location=mlir_stacktrace("mpi.finalize", @__FILE__, @__LINE__)) +# return mpi.finalize(; location) +# end + +function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__)) + sym_name = "enzymexla_wrapper_MPI_Comm_rank" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Comm_rank", "llvm.func @MPI_Comm_rank(!llvm.ptr, !llvm.ptr) -> i32") + + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%rank_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %errcode = llvm.call @MPI_Comm_rank(%comm, %rank_ptr) : (!llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! 
format: on + + rank_placeholder = Reactant.Ops.constant(fill(Cint(-1))) + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) + + res = IR.result( + enzymexla.jit_call( + IR.Value[rank_placeholder.mlir_data]; + fn=sym_attr, + result_0=[IR.TensorType(Int[], IR.Type(Cint))], + location, + output_operand_aliases, + ), + ) + return TracedRNumber{Cint}((), res) +end + +function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__)) + sym_name = "enzymexla_wrapper_MPI_Comm_size" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Comm_size", "llvm.func @MPI_Comm_size(!llvm.ptr, !llvm.ptr) -> i32") + + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%size_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %errcode = llvm.call @MPI_Comm_size(%comm, %size_ptr) : (!llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + size_placeholder = Reactant.Ops.constant(fill(Cint(-1))) + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) + + res = IR.result( + enzymexla.jit_call( + IR.Value[size_placeholder.mlir_data]; + fn=sym_attr, + result_0=[IR.TensorType(Int[], IR.Type(Cint))], + output_operand_aliases, + location, + ), + ) + return TracedRNumber{Cint}((), res) +end + +function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__)) + sym_name = "enzymexla_wrapper_MPI_Barrier" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!("MPI_Barrier", "llvm.func @MPI_Barrier(!llvm.ptr) -> i32") + + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name() -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %status = llvm.call @MPI_Barrier(%comm) : (!llvm.ptr) -> (i32) + func.return + } + """) + #! 
format: on + + output_operand_aliases = IR.Attribute(IR.Attribute[]) + enzymexla.jit_call( + IR.Value[]; fn=sym_attr, result_0=IR.Type[], output_operand_aliases, location + ) + + return nothing +end + +function inject_mpi_datatype!(datatype) + if datatype == MPI.DATATYPE_NULL + IR.inject!( + "MPI_DATATYPE_NULL", + "llvm.mlir.global constant @MPI_DATATYPE_NULL() : !llvm.ptr", + ) + return "MPI_DATATYPE_NULL" + elseif datatype == MPI.BYTE + IR.inject!("MPI_BYTE", "llvm.mlir.global constant @MPI_BYTE() : !llvm.ptr") + return "MPI_BYTE" + # elseif datatype == MPI.PACKED + # IR.inject!("MPI_PACKED", "llvm.mlir.global constant @MPI_PACKED() : !llvm.ptr") + # return "MPI_PACKED" + elseif datatype == MPI.CHAR + IR.inject!("MPI_CHAR", "llvm.mlir.global constant @MPI_CHAR() : !llvm.ptr") + return "MPI_CHAR" + elseif datatype == MPI.SHORT + IR.inject!("MPI_SHORT", "llvm.mlir.global constant @MPI_SHORT() : !llvm.ptr") + return "MPI_SHORT" + elseif datatype == MPI.INT + IR.inject!("MPI_INT", "llvm.mlir.global constant @MPI_INT() : !llvm.ptr") + return "MPI_INT" + elseif datatype == MPI.LONG + IR.inject!("MPI_LONG", "llvm.mlir.global constant @MPI_LONG() : !llvm.ptr") + return "MPI_LONG" + elseif datatype == MPI.FLOAT + IR.inject!("MPI_FLOAT", "llvm.mlir.global constant @MPI_FLOAT() : !llvm.ptr") + return "MPI_FLOAT" + elseif datatype == MPI.DOUBLE + IR.inject!("MPI_DOUBLE", "llvm.mlir.global constant @MPI_DOUBLE() : !llvm.ptr") + return "MPI_DOUBLE" + elseif datatype == MPI.UNSIGNED_CHAR + IR.inject!( + "MPI_UNSIGNED_CHAR", + "llvm.mlir.global constant @MPI_UNSIGNED_CHAR() : !llvm.ptr", + ) + return "MPI_UNSIGNED_CHAR" + elseif datatype == MPI.SIGNED_CHAR + IR.inject!( + "MPI_SIGNED_CHAR", "llvm.mlir.global constant @MPI_SIGNED_CHAR() : !llvm.ptr" + ) + return "MPI_SIGNED_CHAR" + elseif datatype == MPI.UNSIGNED_SHORT + IR.inject!( + "MPI_UNSIGNED_SHORT", + "llvm.mlir.global constant @MPI_UNSIGNED_SHORT() : !llvm.ptr", + ) + return "MPI_UNSIGNED_SHORT" + elseif datatype == MPI.UNSIGNED_LONG + IR.inject!( + "MPI_UNSIGNED_LONG", + "llvm.mlir.global constant @MPI_UNSIGNED_LONG() : !llvm.ptr", + ) + return "MPI_UNSIGNED_LONG" + elseif datatype == MPI.UNSIGNED + IR.inject!("MPI_UNSIGNED", "llvm.mlir.global constant @MPI_UNSIGNED() : !llvm.ptr") + return "MPI_UNSIGNED" + # elseif datatype == MPI.FLOAT_INT + # IR.inject!( + # "MPI_FLOAT_INT", "llvm.mlir.global constant @MPI_FLOAT_INT() : !llvm.ptr" + # ) + # return "MPI_FLOAT_INT" + # elseif datatype == MPI.DOUBLE_INT + # IR.inject!( + # "MPI_DOUBLE_INT", "llvm.mlir.global constant @MPI_DOUBLE_INT() : !llvm.ptr" + # ) + # return "MPI_DOUBLE_INT" + # elseif datatype == MPI.LONG_DOUBLE_INT + # IR.inject!( + # "MPI_LONG_DOUBLE_INT", + # "llvm.mlir.global constant @MPI_LONG_DOUBLE_INT() : !llvm.ptr", + # ) + # return "MPI_LONG_DOUBLE_INT" + # elseif datatype == MPI.LONG_INT + # IR.inject!("MPI_LONG_INT", "llvm.mlir.global constant @MPI_LONG_INT() : !llvm.ptr") + # return "MPI_LONG_INT" + # elseif datatype == MPI.SHORT_INT + # IR.inject!( + # "MPI_SHORT_INT", "llvm.mlir.global constant @MPI_SHORT_INT() : !llvm.ptr" + # ) + # return "MPI_SHORT_INT" + # elseif datatype == MPI.UB + # IR.inject!("MPI_UB", "llvm.mlir.global constant @MPI_UB() : !llvm.ptr") + # return "MPI_UB" + # elseif datatype == MPI.LB + # IR.inject!("MPI_LB", "llvm.mlir.global constant @MPI_LB() : !llvm.ptr") + # return "MPI_LB" + elseif datatype == MPI.WCHAR + IR.inject!("MPI_WCHAR", "llvm.mlir.global constant @MPI_WCHAR() : !llvm.ptr") + return "MPI_WCHAR" + elseif datatype == MPI.LONG_LONG_INT + 
IR.inject!( + "MPI_LONG_LONG_INT", + "llvm.mlir.global constant @MPI_LONG_LONG_INT() : !llvm.ptr", + ) + return "MPI_LONG_LONG_INT" + elseif datatype == MPI.UNSIGNED_LONG_LONG + IR.inject!( + "MPI_UNSIGNED_LONG_LONG", + "llvm.mlir.global constant @MPI_UNSIGNED_LONG_LONG() : !llvm.ptr", + ) + return "MPI_UNSIGNED_LONG_LONG" + elseif datatype == MPI.INT8_T + IR.inject!("MPI_INT8_T", "llvm.mlir.global constant @MPI_INT8_T() : !llvm.ptr") + return "MPI_INT8_T" + elseif datatype == MPI.UINT8_T + IR.inject!("MPI_UINT8_T", "llvm.mlir.global constant @MPI_UINT8_T() : !llvm.ptr") + return "MPI_UINT8_T" + elseif datatype == MPI.INT16_T + IR.inject!("MPI_INT16_T", "llvm.mlir.global constant @MPI_INT16_T() : !llvm.ptr") + return "MPI_INT16_T" + elseif datatype == MPI.UINT16_T + IR.inject!("MPI_UINT16_T", "llvm.mlir.global constant @MPI_UINT16_T() : !llvm.ptr") + return "MPI_UINT16_T" + elseif datatype == MPI.INT32_T + IR.inject!("MPI_INT32_T", "llvm.mlir.global constant @MPI_INT32_T() : !llvm.ptr") + return "MPI_INT32_T" + elseif datatype == MPI.UINT32_T + IR.inject!("MPI_UINT32_T", "llvm.mlir.global constant @MPI_UINT32_T() : !llvm.ptr") + return "MPI_UINT32_T" + elseif datatype == MPI.INT64_T + IR.inject!("MPI_INT64_T", "llvm.mlir.global constant @MPI_INT64_T() : !llvm.ptr") + return "MPI_INT64_T" + elseif datatype == MPI.UINT64_T + IR.inject!("MPI_UINT64_T", "llvm.mlir.global constant @MPI_UINT64_T() : !llvm.ptr") + return "MPI_UINT64_T" + elseif datatype == MPI.AINT + IR.inject!("MPI_AINT", "llvm.mlir.global constant @MPI_AINT() : !llvm.ptr") + return "MPI_AINT" + elseif datatype == MPI.OFFSET + IR.inject!("MPI_OFFSET", "llvm.mlir.global constant @MPI_OFFSET() : !llvm.ptr") + return "MPI_OFFSET" + elseif datatype == MPI.C_BOOL + IR.inject!("MPI_C_BOOL", "llvm.mlir.global constant @MPI_C_BOOL() : !llvm.ptr") + return "MPI_C_BOOL" + elseif datatype == MPI.C_FLOAT_COMPLEX + IR.inject!( + "MPI_C_FLOAT_COMPLEX", + "llvm.mlir.global constant @MPI_C_FLOAT_COMPLEX() : !llvm.ptr", + ) + return "MPI_C_FLOAT_COMPLEX" + elseif datatype == MPI.C_DOUBLE_COMPLEX + IR.inject!( + "MPI_C_DOUBLE_COMPLEX", + "llvm.mlir.global constant @MPI_C_DOUBLE_COMPLEX() : !llvm.ptr", + ) + return "MPI_C_DOUBLE_COMPLEX" + elseif datatype == MPI.COUNT + IR.inject!("MPI_COUNT", "llvm.mlir.global constant @MPI_COUNT() : !llvm.ptr") + return "MPI_COUNT" + else + throw(ArgumentError("Unknown MPI datatype `$datatype`")) + end +end + +function send( + buf::TracedRArray, + tag::TracedRNumber, + dest::TracedRNumber; + location=mlir_stacktrace("mpi.send", @__FILE__, @__LINE__), +) + T = Reactant.unwrapped_eltype(buf) + mpi_datatype = MPI.Datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) + + sym_name = "enzymexla_wrapper_MPI_Send_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Send", + "llvm.func @MPI_Send(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> i32", + ) + + # int MPI_Send(const void* buf, int count, MPI_Datatype datatype, + # int dest, int tag, MPI_Comm comm) + #! 
format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %dest = llvm.load %dest_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + llvm.call @MPI_Send(%buf, %count, %datatype, %dest, %tag, %comm) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + count = Reactant.Ops.constant(Int32(length(buf))) + + enzymexla.jit_call( + IR.Value[buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data]; + fn=sym_attr, + result_0=IR.Type[], + output_operand_aliases=IR.Attribute(IR.Attribute[]), + location, + ) + + return nothing +end + +# TODO need c-function for creating MLIR `mpi.request` type? +function isend( + buf::TracedRArray, + tag::TracedRNumber, + dest::TracedRNumber; + location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__), +) + T = Reactant.unwrapped_eltype(buf) + mpi_datatype = MPI.Datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) + + sym_name = "enzymexla_wrapper_MPI_Isend_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Isend", + "llvm.func @MPI_Isend(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32", + ) + + # int MPI_Isend(const void* buf, int count, MPI_Datatype datatype, + # int dest, int tag, MPI_Comm comm, MPI_Request* request) + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %dest_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr, %req_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %dest = llvm.load %dest_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + %res = llvm.call @MPI_Isend(%buf, %count, %datatype, %dest, %tag, %comm, %req_ptr) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! 
format: on + + count = Reactant.Ops.constant(Int32(length(buf))) + request = Reactant.Ops.constant(Int64(-1)) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 4, 0, C_NULL + ), + ), + ]) + + ret = enzymexla.jit_call( + IR.Value[ + buf.mlir_data, count.mlir_data, dest.mlir_data, tag.mlir_data, request.mlir_data + ]; + fn=sym_attr, + result_0=IR.Type[mlir_type(request)], + output_operand_aliases=output_operand_aliases, + location, + ) + + request.mlir_data = IR.result(ret) + return request # we return a TracedRNumber, converted to TracedRequest in Overrides.jl +end + +function recv!( + recvbuf::TracedRArray, + tag::TracedRNumber, + src::TracedRNumber; + location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__), +) + T = Reactant.unwrapped_eltype(recvbuf) + mpi_datatype = MPI.Datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) + + sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_STATUS_IGNORE", "llvm.mlir.global constant @MPI_STATUS_IGNORE() : !llvm.ptr" + ) + IR.inject!( + "MPI_Recv", + "llvm.func @MPI_Recv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32", + ) + + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %source_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr + %status = llvm.mlir.addressof @MPI_STATUS_IGNORE : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %source = llvm.load %source_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + llvm.call @MPI_Recv(%buf, %count, %datatype, %source, %tag, %comm, %status) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + count = Reactant.Ops.constant(Int32(length(recvbuf))) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 0, 0, C_NULL + ), + ), + ]) + + ret = enzymexla.jit_call( + IR.Value[recvbuf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data]; + fn=sym_attr, + result_0=[mlir_type(recvbuf)], + output_operand_aliases, + location, + ) + + recvbuf.mlir_data = IR.result(ret) + + return recvbuf +end + +# TODO need c-function for creating MLIR `mpi.request` type? +function irecv!( + buf::TracedRArray, + tag::TracedRNumber, + src::TracedRNumber; + location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__), +) + T = Reactant.unwrapped_eltype(buf) + mpi_datatype = MPI.Datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) + + sym_name = "enzymexla_wrapper_MPI_Irecv_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Irecv", + "llvm.func @MPI_Irecv(!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> i32", + ) + + # int MPI_Irecv(void* buf, int count, MPI_Datatype datatype, + # int source, int tag, MPI_Comm comm, MPI_Request* request) + #! 
format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%buf : !llvm.ptr, %count_ptr : !llvm.ptr, %src_ptr : !llvm.ptr, %tag_ptr : !llvm.ptr, %req_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %datatype = llvm.mlir.addressof @$(mpi_datatype_name) : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %src = llvm.load %src_ptr : !llvm.ptr -> i32 + %tag = llvm.load %tag_ptr : !llvm.ptr -> i32 + %res = llvm.call @MPI_Irecv(%buf, %count, %datatype, %src, %tag, %comm, %req_ptr) : (!llvm.ptr, i32, !llvm.ptr, i32, i32, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! format: on + + count = Reactant.Ops.constant(Int32(length(buf))) + request = Reactant.Ops.constant(Int64(-1)) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 1, Ref{Int64}(0), 0, 0, C_NULL + ), + ), + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 1, Ref{Int64}(1), 4, 0, C_NULL + ), + ), + ]) + + ret = enzymexla.jit_call( + IR.Value[ + buf.mlir_data, count.mlir_data, src.mlir_data, tag.mlir_data, request.mlir_data + ]; + fn=sym_attr, + result_0=[mlir_type(buf), mlir_type(request)], + output_operand_aliases=output_operand_aliases, + location, + ) + + buf.mlir_data = IR.result(ret, 1) + request.mlir_data = IR.result(ret, 2) + return request # we return a TracedRNumber, converted to TracedRequest in Overrides.jl +end + +function wait( + req::TracedRequest; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__) +) + sym_name = "enzymexla_wrapper_MPI_Wait" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + # likely isend/irecv will have injected MPI_COMM_WORLD already + IR.tryinject!( + "MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr" + ) + + IR.inject!("MPI_Wait", "llvm.func @MPI_Wait(!llvm.ptr, !llvm.ptr) -> i32") + + # NOTE: Size of status is implem dependent, we try to set it to the max + # int MPI_Wait(MPI_Request* request, MPI_Status* status) + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%req : !llvm.ptr) -> () { + %c1_i32 = arith.constant 1 : i32 + %status = llvm.alloca %c1_i32 x !llvm.array<6 x i32> : (i32) -> !llvm.ptr + llvm.call @MPI_Wait(%req, %status) : (!llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! 
format: on + + enzymexla.jit_call( + IR.Value[req.mlir_data]; + fn=sym_attr, + result_0=IR.Type[], + location, + output_operand_aliases=IR.Attribute(IR.Attribute[]), + ) + + return nothing +end + +function inject_mpi_op!(op) + if op == MPI.OP_NULL + IR.inject!("MPI_OP_NULL", "llvm.mlir.global constant @MPI_OP_NULL() : !llvm.ptr") + return "MPI_OP_NULL" + elseif op == MPI.MAX + IR.inject!("MPI_MAX", "llvm.mlir.global constant @MPI_MAX() : !llvm.ptr") + return "MPI_MAX" + elseif op == MPI.MIN + IR.inject!("MPI_MIN", "llvm.mlir.global constant @MPI_MIN() : !llvm.ptr") + return "MPI_MIN" + elseif op == MPI.SUM + IR.inject!("MPI_SUM", "llvm.mlir.global constant @MPI_SUM() : !llvm.ptr") + return "MPI_SUM" + elseif op == MPI.PROD + IR.inject!("MPI_PROD", "llvm.mlir.global constant @MPI_PROD() : !llvm.ptr") + return "MPI_PROD" + elseif op == MPI.LAND + IR.inject!("MPI_LAND", "llvm.mlir.global constant @MPI_LAND() : !llvm.ptr") + return "MPI_LAND" + elseif op == MPI.BAND + IR.inject!("MPI_BAND", "llvm.mlir.global constant @MPI_BAND() : !llvm.ptr") + return "MPI_BAND" + elseif op == MPI.LOR + IR.inject!("MPI_LOR", "llvm.mlir.global constant @MPI_LOR() : !llvm.ptr") + return "MPI_LOR" + elseif op == MPI.BOR + IR.inject!("MPI_BOR", "llvm.mlir.global constant @MPI_BOR() : !llvm.ptr") + return "MPI_BOR" + elseif op == MPI.LXOR + IR.inject!("MPI_LXOR", "llvm.mlir.global constant @MPI_LXOR() : !llvm.ptr") + return "MPI_LXOR" + elseif op == MPI.BXOR + IR.inject!("MPI_BXOR", "llvm.mlir.global constant @MPI_BXOR() : !llvm.ptr") + return "MPI_BXOR" + elseif op == MPI.REPLACE + IR.inject!("MPI_REPLACE", "llvm.mlir.global constant @MPI_REPLACE() : !llvm.ptr") + return "MPI_REPLACE" + elseif op == MPI.NO_OP + IR.inject!("MPI_NO_OP", "llvm.mlir.global constant @MPI_NO_OP() : !llvm.ptr") + return "MPI_NO_OP" + else + throw(ArgumentError("Unknown MPI operation `$op`")) + end +end + +function allreduce!( + op, sendbuf, recvbuf; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__) +) + @assert Reactant.unwrapped_eltype(sendbuf) == Reactant.unwrapped_eltype(recvbuf) + @assert length(sendbuf) == length(recvbuf) + + op_name = inject_mpi_op!(op) + T = Reactant.unwrapped_eltype(sendbuf) + mpi_datatype = MPI.Datatype(T) + mpi_datatype_name = inject_mpi_datatype!(mpi_datatype) + + IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr") + IR.inject!( + "MPI_Allreduce", + "llvm.func @MPI_Allreduce(!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> i32", + ) + + sym_name = "enzymexla_wrapper_MPI_Allreduce_$(op_name)_$(mpi_datatype_name)" + sym_attr = IR.FlatSymbolRefAttribute(sym_name) + + # TODO is okay to use `i32`? how can we use word-size value or map C's `int` to MLIR? can we use `index`? + #! format: off + IR.inject!(sym_name, """ + func.func @$sym_name(%sendbuf : !llvm.ptr, %recvbuf : !llvm.ptr, %count_ptr : !llvm.ptr) -> () { + %comm = llvm.mlir.addressof @MPI_COMM_WORLD : !llvm.ptr + %op = llvm.mlir.addressof @$op_name : !llvm.ptr + %datatype = llvm.mlir.addressof @$mpi_datatype_name : !llvm.ptr + %count = llvm.load %count_ptr : !llvm.ptr -> i32 + %errcode = llvm.call @MPI_Allreduce(%sendbuf, %recvbuf, %count, %datatype, %op, %comm) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> (i32) + func.return + } + """) + #! 
format: on + + count = Reactant.Ops.constant(fill(length(sendbuf))) + + output_operand_aliases = IR.Attribute([ + IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 0, C_NULL, 1, 0, C_NULL + ), + ), + ]) + + res = IR.result( + enzymexla.jit_call( + IR.Value[sendbuf.mlir_data, recvbuf.mlir_data, count.mlir_data]; + fn=sym_attr, + result_0=IR.Type[Reactant.Ops.mlir_type(typeof(recvbuf), size(recvbuf))], + location, + output_operand_aliases, + ), + ) + + recvbuf.mlir_data = res + + return recvbuf +end + +end # module diff --git a/ext/ReactantMPIExt/Overrides.jl b/ext/ReactantMPIExt/Overrides.jl new file mode 100644 index 0000000000..73638c8e85 --- /dev/null +++ b/ext/ReactantMPIExt/Overrides.jl @@ -0,0 +1,132 @@ +using Reactant: @reactant_overlay, TracedRArray, TracedRNumber + +# @reactant_overlay @noinline function MPI.Init(; kwargs...) +# if !isempty(kwargs) +# @warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs... +# end +# return Ops.init() +# end + +# @reactant_overlay @noinline function MPI.Finalize(; kwargs...) +# return Ops.finalize() +# end + +@reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.comm_rank() +end + +@reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.comm_size() +end + +@reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.barrier() +end + +# TODO status not supported yet +function MPI.Wait(req::TracedRequest) + return Ops.wait(req) +end + +# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` +function MPI.Send(buf::TracedRArray, dest::Integer, tag::Integer, comm::MPI.Comm) + tag = Reactant.Ops.constant(tag) + dest = Reactant.Ops.constant(dest) + return MPI.Send(buf, dest, tag, comm) +end + +# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` +function MPI.Send( + buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm +) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.send(buf, tag, dest) +end + +# TODO should we error if other `AbstractRequest` types are passed in? +function MPI.Isend( + buf::TracedRArray, + dest::Integer, + tag::Integer, + comm::MPI.Comm, + request::TracedRequest=TracedRequest((), nothing), +) + dest = Reactant.Ops.constant(dest) + tag = Reactant.Ops.constant(tag) + + gen_request = MPI.Isend(buf, dest, tag, comm) + request.mlir_data = gen_request.mlir_data + return request +end + +# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer` +function MPI.Isend( + buf::TracedRArray, dest::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm +) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + + return Ops.isend(buf, tag, dest) +end + +function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm) + tag = Reactant.Ops.constant(tag) + source = Reactant.Ops.constant(source) + return MPI.Recv!(buf, source, tag, comm) +end + +# TODO Do we need these? 
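# Illustrative sketch (not a definitive usage pattern): the `Integer` methods above
# only promote `dest`/`source`/`tag` to traced constants via `Reactant.Ops.constant`
# and then re-dispatch to the `TracedRNumber` methods, so both spellings lower to the
# same `Ops.send`/`Ops.recv!`. A hypothetical two-rank helper traced with `@jit`
# (`pingpong!` and its arguments are assumptions for illustration):
#
#     function pingpong!(buf, rank, comm)
#         if rank == 0
#             MPI.Send(buf, 1, 0, comm)    # Integer dest/tag, promoted then re-dispatched
#         else
#             MPI.Recv!(buf, 0, 0, comm)   # same promotion path on the receive side
#         end
#         return nothing
#     end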
+# function MPI.Recv!( +# buf::TracedRArray, +# source::Integer, +# tag::Integer, +# comm::MPI.Comm, +# ::Type{MPI.API.MPI_Status}, +# ) +# return MPI.Recv!(buf, source, tag, comm) +# end + +# function MPI.Recv!( +# buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm, ::Nothing +# ) +# return MPI.Recv!(buf, source, tag, comm) +# end + +# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` +function MPI.Recv!( + buf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm +) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.recv!(buf, tag, source) +end + +function MPI.Irecv!( + buf::TracedRArray, + source::Integer, + tag::Integer, + comm::MPI.Comm, + request::TracedRequest=TracedRequest((), nothing), +) + source = Reactant.Ops.constant(source) + tag = Reactant.Ops.constant(tag) + + gen_request = MPI.Irecv!(buf, source, tag, comm) + request.mlir_data = gen_request.mlir_data + return request +end + +# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer` +function MPI.Irecv!( + buf::TracedRArray, source::TracedRNumber, tag::TracedRNumber, comm::MPI.Comm +) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + + return Ops.irecv!(buf, tag, source) +end + +function MPI.Allreduce!(sendbuf::TracedRArray, recvbuf::TracedRArray, op, comm::MPI.Comm) + @assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently" + return Ops.allreduce!(op, sendbuf, recvbuf) +end diff --git a/ext/ReactantMPIExt/ReactantMPIExt.jl b/ext/ReactantMPIExt/ReactantMPIExt.jl new file mode 100644 index 0000000000..9d4dd7a651 --- /dev/null +++ b/ext/ReactantMPIExt/ReactantMPIExt.jl @@ -0,0 +1,326 @@ +module ReactantMPIExt + +using Reactant +using Reactant: Reactant, Distributed, MLIR +using MPI: MPI +using Libdl + +# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py +Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized() + +function Distributed.get_coordinator_address( + ::Distributed.MPIEnvDetector, timeout_in_seconds::Integer +) + if MPI.Comm_rank(MPI.COMM_WORLD) == 0 + hostname = gethostname() + port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1) + hostname = "$(hostname):$(port_id)" + else + hostname = nothing + end + + return MPI.bcast(hostname, MPI.COMM_WORLD; root=0) +end + +function Distributed.get_process_count(::Distributed.MPIEnvDetector) + return Int(MPI.Comm_size(MPI.COMM_WORLD)) +end + +function Distributed.get_process_id(::Distributed.MPIEnvDetector) + return Int(MPI.Comm_rank(MPI.COMM_WORLD)) +end + +function Distributed.get_local_process_id(::Distributed.MPIEnvDetector) + new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0) + return Int(MPI.Comm_rank(new_comm)) +end + +function __init__() + libmpi_handle = MPI.API.libmpi_handle + + # register MPI routines + for name in [ + :MPI_Init, + :MPI_Finalize, + :MPI_Comm_rank, + :MPI_Comm_size, + :MPI_Send, + :MPI_Recv, + :MPI_Isend, + :MPI_Irecv, + :MPI_Barrier, + :MPI_Wait, + :MPI_Request_free, + ] + sym = Libdl.dlsym(libmpi_handle, name) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid + end + + # register MPI constants + # NOTE these symbols are not ABI-stable until MPI 5.0, but in practice, they are represented as word-size values (i.e. 
`int` or ptr) + for name in [ + # communicators + :MPI_COMM_WORLD, + :MPI_COMM_SELF, + :MPI_COMM_NULL, + :MPI_COMM_TYPE_SHARED, + # datatypes + :MPI_DATATYPE_NULL, + :MPI_BYTE, + :MPI_PACKED, + :MPI_CHAR, + :MPI_SHORT, + :MPI_INT, + :MPI_LONG, + :MPI_FLOAT, + :MPI_DOUBLE, + :MPI_UNSIGNED_CHAR, + :MPI_SIGNED_CHAR, + :MPI_UNSIGNED_SHORT, + :MPI_UNSIGNED_LONG, + :MPI_UNSIGNED, + :MPI_FLOAT_INT, + :MPI_DOUBLE_INT, + :MPI_LONG_DOUBLE_INT, + :MPI_LONG_INT, + :MPI_SHORT_INT, + # :MPI_2INT, + :MPI_UB, + :MPI_LB, + :MPI_WCHAR, + :MPI_LONG_LONG_INT, + :MPI_UNSIGNED_LONG_LONG, + # :MPI_2COMPLEX, + # :MPI_2DOUBLE_COMPLEX, + :MPI_INT8_T, + :MPI_UINT8_T, + :MPI_INT16_T, + :MPI_UINT16_T, + :MPI_INT32_T, + :MPI_UINT32_T, + :MPI_INT64_T, + :MPI_UINT64_T, + :MPI_AINT, + :MPI_OFFSET, + :MPI_C_BOOL, + :MPI_C_FLOAT_COMPLEX, + :MPI_C_DOUBLE_COMPLEX, + # :MPI_C_LONG_DOUBLE_COMPLEX, + :MPI_COUNT, + # ops + :MPI_OP_NULL, + :MPI_MAX, + :MPI_MIN, + :MPI_SUM, + :MPI_PROD, + :MPI_LAND, + :MPI_BAND, + :MPI_LOR, + :MPI_BOR, + :MPI_LXOR, + :MPI_BXOR, + :MPI_MINLOC, + :MPI_MAXLOC, + :MPI_REPLACE, + :MPI_NO_OP, + # request + :MPI_REQUEST_NULL, + # status + :MPI_STATUS_IGNORE, + :MPI_STATUSES_IGNORE, + # error + :MPI_SUCCESS, + :MPI_ERR_BUFFER, + :MPI_ERR_COUNT, + :MPI_ERR_TYPE, + :MPI_ERR_TAG, + :MPI_ERR_COMM, + :MPI_ERR_RANK, + :MPI_ERR_REQUEST, + :MPI_ERR_ROOT, + :MPI_ERR_GROUP, + :MPI_ERR_OP, + :MPI_ERR_TOPOLOGY, + :MPI_ERR_DIMS, + :MPI_ERR_ARG, + :MPI_ERR_UNKNOWN, + :MPI_ERR_TRUNCATE, + :MPI_ERR_OTHER, + :MPI_ERR_INTERN, + :MPI_ERR_IN_STATUS, + :MPI_ERR_PENDING, + :MPI_ERR_ACCESS, + :MPI_ERR_AMODE, + :MPI_ERR_ASSERT, + :MPI_ERR_BAD_FILE, + :MPI_ERR_BASE, + :MPI_ERR_CONVERSION, + :MPI_ERR_DISP, + :MPI_ERR_DUP_DATAREP, + :MPI_ERR_FILE_EXISTS, + :MPI_ERR_FILE_IN_USE, + :MPI_ERR_FILE, + :MPI_ERR_INFO_KEY, + :MPI_ERR_INFO_NOKEY, + :MPI_ERR_INFO_VALUE, + :MPI_ERR_INFO, + :MPI_ERR_IO, + :MPI_ERR_KEYVAL, + :MPI_ERR_LOCKTYPE, + :MPI_ERR_NAME, + :MPI_ERR_NO_MEM, + :MPI_ERR_NOT_SAME, + :MPI_ERR_NO_SPACE, + :MPI_ERR_NO_SUCH_FILE, + :MPI_ERR_PORT, + :MPI_ERR_QUOTA, + :MPI_ERR_READ_ONLY, + :MPI_ERR_RMA_CONFLICT, + :MPI_ERR_RMA_SYNC, + :MPI_ERR_SERVICE, + :MPI_ERR_SIZE, + :MPI_ERR_SPAWN, + :MPI_ERR_UNSUPPORTED_DATAREP, + :MPI_ERR_UNSUPPORTED_OPERATION, + :MPI_ERR_WIN, + # :MPI_T_ERR_MEMORY, + # :MPI_T_ERR_NOT_INITIALIZED, + # :MPI_T_ERR_CANNOT_INIT, + # :MPI_T_ERR_INVALID_INDEX, + # :MPI_T_ERR_INVALID_ITEM, + # :MPI_T_ERR_INVALID_HANDLE, + # :MPI_T_ERR_OUT_OF_HANDLES, + # :MPI_T_ERR_OUT_OF_SESSIONS, + # :MPI_T_ERR_INVALID_SESSION, + # :MPI_T_ERR_CVAR_SET_NOT_NOW, + # :MPI_T_ERR_CVAR_SET_NEVER, + # :MPI_T_ERR_PVAR_NO_STARTSTOP, + # :MPI_T_ERR_PVAR_NO_WRITE, + # :MPI_T_ERR_PVAR_NO_ATOMIC, + :MPI_ERR_RMA_RANGE, + :MPI_ERR_RMA_ATTACH, + :MPI_ERR_RMA_FLAVOR, + :MPI_ERR_RMA_SHARED, + # :MPI_T_ERR_INVALID, + # :MPI_T_ERR_INVALID_NAME, + # :MPI_ERR_PROC_ABORTED, + # :MPI_ERR_PROC_FAILED, + # :MPI_ERR_PROC_FAILED_PENDING, + # :MPI_ERR_REVOKED, + ] + !isdefined(MPI.API, name) && continue + value = getproperty(MPI.API, name) + if value isa Base.RefValue + value = value[] + end + value = convert(Int, value) + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, value::Int)::Cvoid + end +end + +mutable struct TracedRequest <: MPI.AbstractRequest + paths::Tuple + mlir_data::Union{Nothing,Reactant.MLIR.IR.Value} + + function TracedRequest(paths::Tuple, mlir_data::Union{Nothing,Reactant.MLIR.IR.Value}) + if !isnothing(mlir_data) + @assert size(Reactant.MLIR.IR.type(mlir_data)) == () + end + return new(paths, mlir_data) + end +end + +function 
Base.show(io::IOty, X::TracedRequest) where {IOty<:Union{IO,IOContext}} + return print(io, "TracedRequest(", X.paths, ")") +end + +# # NOTE: Commenting out the below on the assumption that a Request will never cross the compile boundary +# # If we ever want to return a request, the below could serve as a starting point +# Reactant.TracedUtils.get_mlir_data(x::TracedRequest) = x.mlir_data +# Reactant.TracedUtils.set_mlir_data!(x::TracedRequest, data) = (x.mlir_data = data; return x) + +# Reactant.TracedUtils.get_paths(x::TracedRequest) = x.paths +# Reactant.TracedUtils.set_paths!(x::TracedRequest, paths) = (x.paths = paths; return x) +# +# function Reactant.Ops.mlir_type(x::TracedRequest)::MLIR.IR.Type +# # return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x))) +# return MLIR.IR.TensorType(collect(Int, ()), MLIR.IR.Type(Int64)) +# end +# +# TODO for this to work properly in finalize_mlir_fn(), need to add TracedRequest to TracedTypes, currently const +# Base.@nospecializeinfer function Reactant.make_tracer( +# seen, +# @nospecialize(prev::TracedRequest), +# @nospecialize(path), +# mode; +# tobatch=nothing, +# toscalar=false, +# @nospecialize(sharding = Sharding.NoSharding()), +# @nospecialize(runtime = nothing), +# kwargs..., +# ) +# if mode == Reactant.NoStopTracedTrack +# Reactant.TracedUtils.set_paths!(prev, (Reactant.TracedUtils.get_paths(prev)..., path)) +# if !haskey(seen, prev) +# seen[prev] = prev # don't return! +# end +# return prev +# end +# if mode == Reactant.TracedToConcrete +# haskey(seen, prev) && return seen[prev]::MPI.Request +# if !Sharding.is_sharded(sharding) +# res = MPI.Request() +# else +# error("Attempting to use sharding and MPI simultaneously") +# end +# seen[prev] = res +# return res +# end +# throw("Trace mode $mode not implemented") +# end +# +# function Reactant.Compiler.create_result( +# tocopy::MPI.Request, +# path, +# result_stores, +# path_to_shard_info, +# to_unreshard_results, +# unresharded_code::Vector{Expr}, +# unresharded_arrays_cache, +# used_shardinfo, +# result_cache, +# var_idx, +# resultgen_code, +# ) +# if !haskey(result_cache, tocopy) +# sym = Symbol("result", var_idx[]) +# var_idx[] += 1 +# +# @assert haskey(result_stores, path) +# restore = result_stores[path] +# delete!(result_stores, path) +# if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) +# error("Attempting to use sharding and MPI simultaneously") +# else +# # TODO +# # restore = result_buffer1 = linearized_results[1] = result of XLA.executesharded() +# # but what is actually returned from XLA.executesharded? +# # Same thing as returned from MPI.Isend (ie, TracedRequest)? 
+# result = :(MPI.Request($restore)) +# end +# push!( +# resultgen_code, +# quote +# $sym = $result +# end, +# ) +# result_cache[tocopy] = sym +# end +# +# return result_cache[tocopy] +# end + +include("Ops.jl") +include("Overrides.jl") + +end # module diff --git a/test/Project.toml b/test/Project.toml index 05b1e8906a..ebb2029081 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,6 +19,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" @@ -53,6 +54,7 @@ KernelAbstractions = "0.9.30" LinearAlgebra = "1.10" Lux = "1.21" LuxLib = "1.11" +MPI = "0.20" NNlib = "0.9.26" OffsetArrays = "1" OneHotArrays = "0.2.6" diff --git a/test/integration/mpi.jl b/test/integration/mpi.jl new file mode 100644 index 0000000000..b7e6859dbe --- /dev/null +++ b/test/integration/mpi.jl @@ -0,0 +1,138 @@ +using Test, MPI, Reactant + +client = Reactant.XLA.default_backend() +Reactant.set_default_backend("cpu") + +MPI.Init() + +@testset "Comm_rank" begin + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + @test rank == @jit MPI.Comm_rank(comm) +end + +@testset "Comm_size" begin + comm = MPI.COMM_WORLD + nranks = MPI.Comm_size(comm) + @test nranks == @jit MPI.Comm_size(comm) +end + +@testset "Allreduce" begin + comm = MPI.COMM_WORLD + x = ConcreteRArray(fill(1)) + nranks = MPI.Comm_size(comm) + @test nranks == @jit MPI.Allreduce(x, MPI.SUM, MPI.COMM_WORLD) +end + +@testset "Barrier" begin + @testset "Single Barrier" begin + comm = MPI.COMM_WORLD + ret = @jit MPI.Barrier(comm) + @test ret === nothing + end + + @testset "Consecutive Barriers" begin + comm = MPI.COMM_WORLD + for i in 1:3 + @test_nowarn @jit MPI.Barrier(comm) + end + end +end + +@testset "Send / Recv!" begin + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + nranks = MPI.Comm_size(comm) + + # # useful for isolating whether Reactant Send or Recv! is the issue + # @testset "MPI.jl Send / Reactant Recv!" begin + # send_buf = ones(5) + # tag = 43 + # if rank == 0 + # MPI.Send(send_buf, comm; dest=1, tag=tag) + # elseif rank == 1 + # recv_buf = ConcreteRArray(zeros(5)) + # source = 0 + # @jit MPI.Recv!(recv_buf, source, tag, comm) + # @test recv_buf == send_buf + # end + # end + # @testset "Reactant Send / MPI.jl Recv!" begin + # send_buf = ConcreteRArray(ones(5)) + # tag = 43 + # if rank == 0 + # dest = 1 + # @jit MPI.Send(send_buf, dest, tag, comm) + # elseif rank == 1 + # recv_buf = zeros(5) + # MPI.Recv!(recv_buf, comm; source=0, tag=tag) + # @test recv_buf == send_buf + # end + # end + + # test Reactant Send/Recv + @testset "Reactant Send / Recv! - compiled separately" begin + send_buf = ConcreteRArray(ones(5)) + tag = 43 + if rank == 0 + dest = 1 + @jit MPI.Send(send_buf, dest, tag, comm) + elseif rank == 1 + recv_buf = ConcreteRArray(zeros(5)) + src = 0 + @jit MPI.Recv!(recv_buf, src, tag, comm) + @test recv_buf == send_buf + end + end + + @testset "Reactant Send / Recv! 
- compiled together" begin + send_buf = ConcreteRArray(ones(5)) + recv_buf = ConcreteRArray(zeros(5)) + tag = 43 + function sendrecv!(comm, rank, send_buf, recv_buf, tag) + if rank == 0 + dest = 1 + MPI.Send(send_buf, dest, tag, comm) + return nothing + elseif rank == 1 + src = 0 + MPI.Recv!(recv_buf, src, tag, comm) + return nothing + end + end + @jit sendrecv!(comm, rank, send_buf, recv_buf, tag) + rank == 1 && @test recv_buf == send_buf + end +end + +@testset "Isend / Irecv! / Wait" begin + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + nranks = MPI.Comm_size(comm) + + # NOTE: currently don't allow a request to cross the compile boundary + # debugging tip: if this fails, can use pair Send with Irecv! + Wait, or Recv! with + # Isend + Wait to isolate the issue + send_buf = ConcreteRArray(ones(5)) + recv_buf = ConcreteRArray(zeros(5)) + tag = 42 + function isendirecvwait(send_buf, recv_buf, rank, tag, comm) + if rank == 0 + dest = 1 + req = MPI.Isend(send_buf, dest, tag, comm) + MPI.Wait(req) + return nothing + elseif rank == 1 + src = 0 + req = MPI.Irecv!(recv_buf, src, tag, comm) + MPI.Wait(req) + return nothing + end + end + @jit isendirecvwait(send_buf, recv_buf, rank, tag, comm) + rank == 1 && @test recv_buf == send_buf +end + +MPI.Finalize() + +Reactant.set_default_backend(client) diff --git a/test/runtests.jl b/test/runtests.jl index 98a02a7de0..5edaa478e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,11 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") @safetestset "Zygote" include("integration/zygote.jl") + @safetestset "MPI" begin + using MPI + nranks = 2 + run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`) + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
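A minimal way to exercise the new integration test outside the harness, mirroring the `runtests.jl` invocation above (this assumes the command is run from the repository root, that the `test` environment is instantiated, and that `--project=test` is the right project flag for this layout):

    using MPI
    run(`$(mpiexec()) -n 2 $(Base.julia_cmd()) --project=test test/integration/mpi.jl`)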