Support MPI #752
Conversation
You won't; instead you'll emit something like an mpi.send op, and then the lower-jit pass will convert it into a custom call. However, you will need to define a lowering of mpi.send into a corresponding MPI_Send call [which will use the symbol you just registered here]. Re CUDA, though: we also need to ensure we are synced with respect to the current CUDA stream, which you can get via enzymexla.get_stream.
Mmm, from our last discussion on this a couple of weeks ago, I understood that we would emit this:

```
function main() {
  ...
  mpi.send(%arg0, ...)
  ...
}
```

and it would get lowered to:

```
function send_wrap(%arg : memref<axb>) {
  llvm.call <0xffff> (%arg)
}

function main() {
  ...
  enzymexla.jit_call @send_wrap(%x : tensor<...>)
  ...
}
```

which will finally lower to the following with the enzymexla.jit pass:

```
function main() {
  ...
  stablehlo.custom_call @mpi_send_wrap(%x : tensor<...>)
  ...
}
```

Is this correct, or do we need to emit the... ahh, or do you mean that any wrapping we need to do around MPI should be done in this way?
Okay, this will probably be required for NCCL.
@wsmoses are there any other features we want to add before merging? If not, this might be ready for review.
```julia
using Libdl

# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized()
```
@avik-pal can you review this?
This looks good to me
```julia
using Test, MPI, Reactant

# # MPI only works on cpu currently --- is this the right way/place to enforce that?
# Reactant.set_default_backend("cpu")
```
@avik-pal re integration bits
For safety I would do `client = Reactant.default_backend(); Reactant.set_default_backend("cpu")` and at the end of the script set the client back with `Reactant.set_default_backend(client)`. Technically we are running on a separate process, so it shouldn't matter, but if we ever `include` the file during testing (local/CI), it will make debugging harder.
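A minimal sketch of that save/restore pattern (the `try`/`finally` wrapper is an embellishment to make the restore robust if a test throws):

```julia
using Reactant

# Remember the current client so we can restore it afterwards.
client = Reactant.default_backend()
Reactant.set_default_backend("cpu")

try
    # ... run the MPI tests on the CPU backend ...
finally
    # Restore the original client even if a test errors.
    Reactant.set_default_backend(client)
end
```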
```julia
# # MPI only works on cpu currently --- is this the right way/place to enforce that?
# Reactant.set_default_backend("cpu")

MPI.Init()
```
Do I understand correctly that this (and `Finalize`) can't be `@compile`d at the moment?
Yeah, that's right. Although it looks like there are overrides in Overrides.jl. Let me see if I can get these working easily; otherwise maybe we just remove them, unless it's a priority.
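For context, a hedged sketch of how a test might drive this today, with `MPI.Init()`/`MPI.Finalize()` called eagerly on the host rather than inside the `@compile`d function (the `halo_exchange!` name and the exact `Send`/`Recv!` overload behavior are assumptions, not the PR's confirmed API):

```julia
using MPI, Reactant

MPI.Init()  # eager host call, not traced or compiled

comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)  # plain host Int; branches on it resolve at trace time

# Hypothetical traced function: only the communication itself is compiled.
function halo_exchange!(buf)
    if rank == 0
        MPI.Send(buf, 1, 0, comm)   # dest=1, tag=0
    else
        MPI.Recv!(buf, 0, 0, comm)  # source=0, tag=0
    end
    return buf
end

buf = Reactant.to_rarray(ones(4))
compiled = @compile halo_exchange!(buf)
compiled(buf)

MPI.Finalize()  # also eager
```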
Generally LGTM, but I want @avik-pal to look at the `Distributed.is_env_present` bits and related code to double-check.
Sounds good. There are a couple of things I'm trying to clean up in the meantime.
```diff
-    tag::Integer,
-    comm::MPI.Comm
-)
+function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm)
```
That's JuliaFormatter (I only accepted the suggestions); whether to wrap or not depends on the length of all the arguments on a single line.
Interesting, alright, fair enough.
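For illustration, a hedged sketch of that wrapping rule; the second signature is hypothetical, and the threshold is the formatter's `margin` setting (92 characters by default):

```julia
# Fits within the margin, so JuliaFormatter keeps it on one line:
function MPI.Recv!(buf::TracedRArray, source::Integer, tag::Integer, comm::MPI.Comm)
    # ...
end

# A longer (hypothetical) signature exceeds the margin, so each
# argument is wrapped onto its own line:
function MPI.Recv!(
    buf::TracedRArray,
    source::Integer,
    tag::Integer,
    comm::MPI.Comm,
    status::Union{Nothing,MPI.Status},
)
    # ...
end
```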
This PR...

Unresolved questions:
- How can we represent `MPI_Request` with `tensor` and `stablehlo` types?
- Mmm, `stablehlo.custom_call` has a `backend` attribute that could be useful during lowering; e.g. if we want to lower to NCCL instead of MPI, since both have a similar API, we could potentially add our own custom C functions that use NCCL but adapt them to an MPI-like API.
- @wsmoses can we create `@cfunction`s in Julia and pass them to the symbol table (see the sketch after this list)? Some MPI routines might need a lil bit of adaptation, and writing them in Julia would be easier and faster (and would also use the correct symbols from the MPI.jl-loaded libmpi).

Tested:
- `side_effect=true` on `enzymexla.jit_call`

To do:

cc @JBlaschke @hhkit
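A minimal sketch of the `@cfunction` idea from the list above; `register_symbol!` is a hypothetical stand-in for whatever registration hook the JIT symbol table actually exposes:

```julia
using MPI

# Hypothetical stand-in for the real symbol-table registration hook.
register_symbol!(name::Symbol, ptr::Ptr{Cvoid}) = nothing

# Julia-side wrapper with a C-compatible signature. Its body would
# adapt the arguments and forward to MPI.jl's low-level bindings, so
# it uses the same libmpi that MPI.jl loaded.
function send_wrap(buf::Ptr{Cvoid}, count::Cint, dest::Cint, tag::Cint)::Cint
    # ... adapt arguments and call the underlying MPI_Send here ...
    return Cint(0)
end

# @cfunction yields a C function pointer that can be registered under
# a known name for custom calls to resolve.
send_ptr = @cfunction(send_wrap, Cint, (Ptr{Cvoid}, Cint, Cint, Cint))
register_symbol!(:mpi_send_wrap, send_ptr)
```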