Skip to content

Commit c6def59

Browse files
committed
Remove function convert_julia_type_to_mpi_datatype(), use MPI.Dataype()
instead
1 parent a4e159f commit c6def59

File tree

1 file changed

+5
-29
lines changed

1 file changed

+5
-29
lines changed

ext/ReactantMPIExt/Ops.jl

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -274,38 +274,14 @@ function inject_mpi_datatype!(datatype)
274274
end
275275
end
276276

277-
function convert_julia_type_to_mpi_datatype(T::Type)
278-
if T === Bool
279-
MPI.C_BOOL
280-
elseif T === Int8
281-
MPI.INT8_T
282-
elseif T === Int16
283-
MPI.INT16_T
284-
elseif T === Int32
285-
MPI.INT32_T
286-
elseif T === Int64
287-
MPI.INT64_T
288-
elseif T === Float32
289-
MPI.FLOAT
290-
elseif T === Float64
291-
MPI.DOUBLE
292-
elseif T === ComplexF32
293-
MPI.C_FLOAT_COMPLEX
294-
elseif T === ComplexF64
295-
MPI.C_DOUBLE_COMPLEX
296-
else
297-
throw(ArgumentError("Unknown conversion from $T to a MPI_Datatype"))
298-
end
299-
end
300-
301277
function send(
302278
buf::TracedRArray,
303279
tag::TracedRNumber,
304280
dest::TracedRNumber;
305281
location=mlir_stacktrace("mpi.send", @__FILE__, @__LINE__),
306282
)
307283
T = Reactant.unwrapped_eltype(buf)
308-
mpi_datatype = convert_julia_type_to_mpi_datatype(T)
284+
mpi_datatype = MPI.Datatype(T)
309285
mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
310286

311287
sym_name = "enzymexla_wrapper_MPI_Send_$(mpi_datatype_name)"
@@ -356,7 +332,7 @@ function isend(
356332
location=mlir_stacktrace("mpi.isend", @__FILE__, @__LINE__),
357333
)
358334
T = Reactant.unwrapped_eltype(buf)
359-
mpi_datatype = convert_julia_type_to_mpi_datatype(T)
335+
mpi_datatype = MPI.Datatype(T)
360336
mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
361337

362338
sym_name = "enzymexla_wrapper_MPI_Isend_$(mpi_datatype_name)"
@@ -416,7 +392,7 @@ function recv!(
416392
location=mlir_stacktrace("mpi.recv", @__FILE__, @__LINE__),
417393
)
418394
T = Reactant.unwrapped_eltype(recvbuf)
419-
mpi_datatype = convert_julia_type_to_mpi_datatype(T)
395+
mpi_datatype = MPI.Datatype(T)
420396
mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
421397

422398
sym_name = "enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name)"
@@ -477,7 +453,7 @@ function irecv!(
477453
location=mlir_stacktrace("mpi.irecv", @__FILE__, @__LINE__),
478454
)
479455
T = Reactant.unwrapped_eltype(buf)
480-
mpi_datatype = convert_julia_type_to_mpi_datatype(T)
456+
mpi_datatype = MPI.Datatype(T)
481457
mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
482458

483459
sym_name = "enzymexla_wrapper_MPI_Irecv_$(mpi_datatype_name)"
@@ -626,7 +602,7 @@ function allreduce!(
626602

627603
op_name = inject_mpi_op!(op)
628604
T = Reactant.unwrapped_eltype(sendbuf)
629-
mpi_datatype = convert_julia_type_to_mpi_datatype(T)
605+
mpi_datatype = MPI.Datatype(T)
630606
mpi_datatype_name = inject_mpi_datatype!(mpi_datatype)
631607

632608
IR.inject!("MPI_COMM_WORLD", "llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr")

0 commit comments

Comments
 (0)