Skip to content

Commit 144ccdd

Browse files
committed
Register MPI symbols on load
1 parent 2d4da3f commit 144ccdd

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

ext/ReactantMPIExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module ReactantMPIExt
22

3+
using Reactant
34
using Reactant: Reactant, Distributed
45
using MPI: MPI
56

@@ -33,4 +34,22 @@ function Distributed.get_local_process_id(::Distributed.MPIEnvDetector)
3334
return Int(MPI.Comm_rank(new_comm))
3435
end
3536

37+
function __init__()
38+
for name in (
39+
"MPI_Init",
40+
"MPI_Finalize",
41+
"MPI_Comm_rank",
42+
"MPI_Comm_size",
43+
"MPI_Send",
44+
"MPI_Recv",
45+
"MPI_Isend",
46+
"MPI_Irecv",
47+
"MPI_Wait",
48+
"MPI_Request_free",
49+
)
50+
sym = Libdl.dlsym(MPI.API.libmpi, name)
51+
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(name::Cstring, sym::Ptr{Cvoid})::Cvoid
52+
end
53+
end
54+
3655
end

0 commit comments

Comments
 (0)