Skip to content

Commit bfe67a9

Browse files
committed
ops
1 parent 144ccdd commit bfe67a9

File tree

3 files changed

+198
-0
lines changed

3 files changed

+198
-0
lines changed

ext/ReactantMPIExt/Ops.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
module Ops
2+
using Reactant: TracedRArray, TracedRNumber
3+
using Reactant: MLIR
4+
using Reactant.MLIR.Dialects: mpi
5+
using ..ReactantMPIExt: TracedRequest
6+
7+
# TODO add communicators
8+
9+
function init(; location=mlir_stacktrace("mpi.init", @__FILE__, @__LINE__))
10+
return mpi.init(; location)
11+
end
12+
13+
function finalize(; location=mlir_stacktrace("mpi.finalize", @__FILE__, @__LINE__))
14+
return mpi.finalize(; location)
15+
end
16+
17+
function comm_rank(; location=mlir_stacktrace("mpi.comm_rank", @__FILE__, @__LINE__))
18+
res = MLIR.IR.result(mpi.comm_rank(; location))
19+
return TracedRNumber{Int}((), res)
20+
end
21+
22+
function comm_size(; location=mlir_stacktrace("mpi.comm_size", @__FILE__, @__LINE__))
23+
res = MLIR.IR.result(mpi.comm_size(; location))
24+
return TracedRNumber{Int}((), res)
25+
end
26+
27+
# TODO should we emit `stablehlo.optimization_barrier` here too?
28+
function barrier(; location=mlir_stacktrace("mpi.barrier", @__FILE__, @__LINE__))
29+
return mpi.barrier(; location)
30+
end
31+
32+
function send(
33+
buf::TracedRArray,
34+
tag::TracedRNumber,
35+
dest::TracedRNumber;
36+
location=mlir_stacktrace("mpi.send", @__FILE__, @__LINE__),
37+
)
38+
return mpi.send(buf.mlir_data, tag.mlir_data, dest.mlir_data; location)
39+
end
40+
41+
# TODO need c-function for creating MLIR `mpi.request` type?
42+
function isend(
43+
buf::TracedRArray,
44+
tag::TracedRNumber,
45+
dest::TracedRNumber;
46+
location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__),
47+
)
48+
return TracedRequest(
49+
MLIR.IR.result(mpi.isend(buf.mlir_data, tag.mlir_data, dest.mlir_data; location))
50+
)
51+
end
52+
53+
function recv!(
54+
ref::TracedRArray,
55+
tag::TracedRNumber,
56+
src::TracedRNumber;
57+
location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__),
58+
)
59+
return mpi.recv(ref.mlir_data, tag.mlir_data, src.mlir_data; location)
60+
end
61+
62+
# TODO need c-function for creating MLIR `mpi.request` type?
63+
function irecv!(
64+
ref::TracedRArray,
65+
tag::TracedRNumber,
66+
src::TracedRNumber;
67+
location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__),
68+
)
69+
return TracedRequest(
70+
MLIR.IR.result(mpi.irecv(ref.mlir_data, tag.mlir_data, src.mlir_data; location))
71+
)
72+
end
73+
74+
function wait(
75+
req::TracedRequest; location=mlir_stacktrace("mpi.wait", @__FILE__, @__LINE__)
76+
)
77+
return mpi.wait(req.mlir_data; location)
78+
end
79+
80+
end # module

ext/ReactantMPIExt/Overrides.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
@reactant_overlay @noinline function MPI.Init(; kwargs...)
2+
if !isempty(kwargs)
3+
@warn "Ignoring MPI.Init kwargs when tracing over MPI..." kwargs...
4+
end
5+
return Ops.init()
6+
end
7+
8+
@reactant_overlay @noinline function MPI.Init(; kwargs...)
9+
return Ops.finalize()
10+
end
11+
12+
@reactant_overlay @noinline function MPI.Comm_rank(comm::MPI.Comm)
13+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
14+
return Ops.comm_rank()
15+
end
16+
17+
@reactant_overlay @noinline function MPI.Comm_size(comm::MPI.Comm)
18+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
19+
return Ops.comm_size()
20+
end
21+
22+
@reactant_overlay @noinline function MPI.Barrier(comm::MPI.Comm)
23+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
24+
return Ops.barrier()
25+
end
26+
27+
# TODO status not supported yet
28+
function MPI.Wait(req::TracedRequest)
29+
return Ops.wait(req)
30+
end
31+
32+
# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer`
33+
function MPI.Send(buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm)
34+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
35+
36+
tag = if !(tag isa TracedRNumber)
37+
Ops.constant(tag)
38+
end
39+
40+
dest = if !(dest isa TracedRNumber)
41+
Ops.constant(dest)
42+
end
43+
44+
return Ops.send(buf, tag, dest)
45+
end
46+
47+
# TODO use `make_tracer` to linearize arbitrary types? check out `MPI.Buffer`
48+
function MPI.Isend(buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm)
49+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
50+
51+
tag = if !(tag isa TracedRNumber)
52+
Ops.constant(tag)
53+
end
54+
55+
return dest = if !(dest isa TracedRNumber)
56+
Ops.constant(dest)
57+
end
58+
59+
return Ops.isend(buf, tag, dest)
60+
end
61+
62+
# TODO should we error if other `AbstractRequest` types are passed in?
63+
function MPI.Isend(
64+
buf::TracedRArray, dest::Number, tag::Number, comm::MPI.Comm, req::TracedRequest
65+
)
66+
gen_req = MPI.Isend(buf, dest, tag, comm)
67+
req.mlir_data = gen_req.mlir_data
68+
return req
69+
end
70+
71+
# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer`
72+
function MPI.Recv!(
73+
recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, status
74+
)
75+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
76+
@assert isnothing(status) "Status not supported yet"
77+
78+
tag = if !(tag isa TracedRNumber)
79+
Ops.constant(tag)
80+
end
81+
82+
source = if !(source isa TracedRNumber)
83+
Ops.constant(source)
84+
end
85+
86+
return Ops.recv(recvbuf, tag, source)
87+
end
88+
89+
# TODO use `make_tracer` to delinearize arbitrary types? check out `MPI.Buffer`
90+
function MPI.IRecv!(recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm)
91+
@assert comm == MPI.COMM_WORLD "Only MPI.COMM_WORLD is supported currently"
92+
93+
tag = if !(tag isa TracedRNumber)
94+
Ops.constant(tag)
95+
end
96+
97+
source = if !(source isa TracedRNumber)
98+
Ops.constant(source)
99+
end
100+
101+
return Ops.irecv!(recvbuf, tag, source)
102+
end
103+
104+
function MPI.IRecv!(
105+
recvbuf::TracedRArray, source::Number, tag::Number, comm::MPI.Comm, req::TracedRequest
106+
)
107+
gen_req = MPI.IRecv!(recvbuf, source, tag, comm)
108+
req.mlir_data = gen_req.mlir_data
109+
return req
110+
end

ext/ReactantMPIExt.jl renamed to ext/ReactantMPIExt/ReactantMPIExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ function __init__()
4444
"MPI_Recv",
4545
"MPI_Isend",
4646
"MPI_Irecv",
47+
"MPI_Barrier",
4748
"MPI_Wait",
4849
"MPI_Request_free",
4950
)
@@ -52,4 +53,11 @@ function __init__()
5253
end
5354
end
5455

56+
struct TracedRequest <: MPI.AbstractRequest
57+
mlir_data::Union{Nothing,MLIR.IR.Value}
5558
end
59+
60+
include("Ops.jl")
61+
include("Overrides.jl")
62+
63+
end # module

0 commit comments

Comments
 (0)