@@ -274,38 +274,14 @@ function inject_mpi_datatype!(datatype)
274
274
end
275
275
end
276
276
277
- function convert_julia_type_to_mpi_datatype (T:: Type )
278
- if T === Bool
279
- MPI. C_BOOL
280
- elseif T === Int8
281
- MPI. INT8_T
282
- elseif T === Int16
283
- MPI. INT16_T
284
- elseif T === Int32
285
- MPI. INT32_T
286
- elseif T === Int64
287
- MPI. INT64_T
288
- elseif T === Float32
289
- MPI. FLOAT
290
- elseif T === Float64
291
- MPI. DOUBLE
292
- elseif T === ComplexF32
293
- MPI. C_FLOAT_COMPLEX
294
- elseif T === ComplexF64
295
- MPI. C_DOUBLE_COMPLEX
296
- else
297
- throw (ArgumentError (" Unknown conversion from $T to a MPI_Datatype" ))
298
- end
299
- end
300
-
301
277
function send (
302
278
buf:: TracedRArray ,
303
279
tag:: TracedRNumber ,
304
280
dest:: TracedRNumber ;
305
281
location= mlir_stacktrace (" mpi.send" , @__FILE__ , @__LINE__ ),
306
282
)
307
283
T = Reactant. unwrapped_eltype (buf)
308
- mpi_datatype = convert_julia_type_to_mpi_datatype (T)
284
+ mpi_datatype = MPI . Datatype (T)
309
285
mpi_datatype_name = inject_mpi_datatype! (mpi_datatype)
310
286
311
287
sym_name = " enzymexla_wrapper_MPI_Send_$(mpi_datatype_name) "
@@ -356,7 +332,7 @@ function isend(
356
332
location= mlir_stacktrace (" mpi.isend" , @__FILE__ , @__LINE__ ),
357
333
)
358
334
T = Reactant. unwrapped_eltype (buf)
359
- mpi_datatype = convert_julia_type_to_mpi_datatype (T)
335
+ mpi_datatype = MPI . Datatype (T)
360
336
mpi_datatype_name = inject_mpi_datatype! (mpi_datatype)
361
337
362
338
sym_name = " enzymexla_wrapper_MPI_Isend_$(mpi_datatype_name) "
@@ -416,7 +392,7 @@ function recv!(
416
392
location= mlir_stacktrace (" mpi.recv" , @__FILE__ , @__LINE__ ),
417
393
)
418
394
T = Reactant. unwrapped_eltype (recvbuf)
419
- mpi_datatype = convert_julia_type_to_mpi_datatype (T)
395
+ mpi_datatype = MPI . Datatype (T)
420
396
mpi_datatype_name = inject_mpi_datatype! (mpi_datatype)
421
397
422
398
sym_name = " enzymexla_wrapper_MPI_Recv_$(mpi_datatype_name) "
@@ -477,7 +453,7 @@ function irecv!(
477
453
location= mlir_stacktrace (" mpi.irecv" , @__FILE__ , @__LINE__ ),
478
454
)
479
455
T = Reactant. unwrapped_eltype (buf)
480
- mpi_datatype = convert_julia_type_to_mpi_datatype (T)
456
+ mpi_datatype = MPI . Datatype (T)
481
457
mpi_datatype_name = inject_mpi_datatype! (mpi_datatype)
482
458
483
459
sym_name = " enzymexla_wrapper_MPI_Irecv_$(mpi_datatype_name) "
@@ -626,7 +602,7 @@ function allreduce!(
626
602
627
603
op_name = inject_mpi_op! (op)
628
604
T = Reactant. unwrapped_eltype (sendbuf)
629
- mpi_datatype = convert_julia_type_to_mpi_datatype (T)
605
+ mpi_datatype = MPI . Datatype (T)
630
606
mpi_datatype_name = inject_mpi_datatype! (mpi_datatype)
631
607
632
608
IR. inject! (" MPI_COMM_WORLD" , " llvm.mlir.global constant @MPI_COMM_WORLD() : !llvm.ptr" )
0 commit comments