Skip to content

Commit dcbf3f9

Browse files
authored
feat: slurm detector + multigpu single process (#891)
* feat: slurm detector + multigpu single process * docs: add multi-host docs * fix: move slurm to higher priority * test: add simple tests for cluster_detection * feat: better error messages + use PMIX version number for detection * fix: check for ifrt_array_copy_to_host_buffer * test: check which failed * fix: condition * fix: warn if users are running Distributed in interactive mode
1 parent 74e8098 commit dcbf3f9

File tree

11 files changed

+317
-19
lines changed

11 files changed

+317
-19
lines changed

docs/make.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ pages = [
3232
"Getting Started" => "introduction/index.md",
3333
"Configuration" => "introduction/configuration.md",
3434
],
35-
"Tutorials" =>
36-
["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"],
35+
"Tutorials" => [
36+
"Overview" => "tutorials/index.md",
37+
"Profiling" => "tutorials/profiling.md",
38+
"Distributed" => "tutorials/multihost.md",
39+
],
3740
"API Reference" => [
3841
"Reactant API" => "api/api.md",
3942
"Ops" => "api/ops.md",

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ export default defineConfig({
6363
items: [
6464
{text: "Overview", link: "/tutorials/"},
6565
{text: "Profiling", link: "/tutorials/profiling"},
66+
{text: "Distributed", link: "/tutorials/multihost"},
6667
],
6768
},
6869
{
@@ -122,6 +123,7 @@ export default defineConfig({
122123
items: [
123124
{ text: "Overview", link: "/tutorials/" },
124125
{ text: "Profiling", link: "/tutorials/profiling" },
126+
{ text: "Distributed", link: "/tutorials/multihost" },
125127
],
126128
},
127129
"/api/": {

docs/src/api/sharding.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
CollapsedDocStrings = true
33
```
44

5-
# Sharding API
5+
# [Sharding API](@id sharding-api)
66

77
`Reactant.Sharding` module provides a high-level API to construct MLIR operations with
88
support for sharding.

docs/src/tutorials/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Tutorials
22

33
- [Profiling](@ref profiling).
4+
- [Multi-Host Environments](@ref distributed).
45

56
We are currently working on adding more tutorials to Reactant!! Please check back soon!

docs/src/tutorials/multihost.md

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# [Multi-Host Environments](@id distributed)
2+
3+
!!! tip "Use XLA IFRT Runtime"
4+
5+
While PJRT does support some minimal distributed capabilities on CUDA GPUs, distributed
6+
support in Reactant is primarily provided via IFRT. Before loading Reactant, set the
7+
"xla_runtime" preference to be "IFRT". This can be done with:
8+
9+
```julia
10+
using Preferences, UUIDs
11+
12+
Preferences.set_preference!(
13+
UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
14+
"xla_runtime" => "IFRT"
15+
)
16+
```
17+
18+
At the top of your code, just after loading Reactant and before running any Reactant related
19+
operations, run `Reactant.Distributed.initialize()`.
20+
21+
!!! tip "Enable debug logging for debugging"
22+
23+
Reactant emits a lot of useful debugging information when setting up the Distributed
24+
Runtime. This can be printed by setting the env var `JULIA_DEBUG` to contain
25+
`Reactant`.
26+
27+
After this simply setup your code with [`Reactant.Sharding`](@ref sharding-api) and the code
28+
will run on multiple devices across multiple nodes.
29+
30+
## Example Slurm Script for Multi-Host Matrix Multiplication
31+
32+
::: code-group
33+
34+
```bash [main.sbatch]
35+
#!/bin/bash -l
36+
#
37+
#SBATCH --job-name=matmul-sharding-reactant
38+
#SBATCH --time=00:20:00
39+
#SBATCH --nodes=2
40+
#SBATCH --ntasks-per-node=1
41+
#SBATCH --cpus-per-task=72
42+
#SBATCH --account=<account>
43+
#SBATCH --constraint=gpu
44+
45+
export JULIA_DEBUG="Reactant,Reactant_jll"
46+
47+
srun --preserve-env bash ./matmul.sh
48+
```
49+
50+
```bash [matmul.sh]
51+
#!/bin/bash -l
52+
53+
# Important else XLA might hang indefinitely
54+
unset no_proxy http_proxy https_proxy NO_PROXY HTTP_PROXY HTTPS_PROXY
55+
56+
julia --project=. --threads=auto matmul_sharded.jl
57+
```
58+
59+
```julia [matmul_sharded.jl]
60+
using Reactant
61+
62+
Reactant.Distributed.initialize(; single_gpu_per_process=false)
63+
64+
@assert length(Reactant.devices()) >= 2
65+
66+
N = min((length(Reactant.devices()) ÷ 2) * 2, 8)
67+
68+
mesh = Sharding.Mesh(reshape(Reactant.devices()[1:N], 2, :), (:x, :y))
69+
sharding = Sharding.NamedSharding(mesh, (:x, :y))
70+
71+
x = reshape(collect(Float32, 1:64), 8, 8)
72+
y = reshape(collect(Float32, 1:64), 8, 8)
73+
74+
x_ra = Reactant.to_rarray(x; sharding)
75+
y_ra = Reactant.to_rarray(y; sharding)
76+
77+
res = @jit x_ra * y_ra
78+
79+
display(res)
80+
```
81+
82+
:::

src/Distributed.jl

Lines changed: 103 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,16 @@ function initialize(;
88
coordinator_address::Union{Nothing,String}=nothing,
99
num_processes::Union{Nothing,Integer}=nothing,
1010
process_id::Union{Nothing,Integer}=nothing,
11+
single_gpu_per_process::Bool=true,
1112
local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
1213
initialization_timeout_in_seconds::Integer=300,
1314
kwargs...,
1415
)
16+
if isinteractive()
17+
@warn "Reactant.Distributed.initialize() should not be called in interactive mode. \
18+
Use Reactant.Distributed.initialize() in a script instead."
19+
end
20+
1521
@assert !initialized[] "`Distributed.initialize` has already been called"
1622

1723
(coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params(;
@@ -20,6 +26,7 @@ function initialize(;
2026
process_id,
2127
local_gpu_device_ids,
2228
initialization_timeout_in_seconds,
29+
single_gpu_per_process,
2330
)
2431

2532
@debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids
@@ -43,6 +50,8 @@ struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end
4350

4451
struct MPIEnvDetector <: AbstractClusterEnvDetector end
4552

53+
struct SlurmEnvDetector <: AbstractClusterEnvDetector end
54+
4655
# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py
4756

4857
is_env_present(::AbstractClusterEnvDetector) = false
@@ -53,12 +62,19 @@ function get_process_id end
5362
function get_local_process_id end
5463

5564
function auto_detect_unset_distributed_params(;
56-
detector_list=[OpenMPIORTEEnvDetector(), OpenMPIPMIXEnvDetector(), MPIEnvDetector()],
65+
detector_list=[
66+
SlurmEnvDetector(),
67+
OpenMPIORTEEnvDetector(),
68+
MPIEnvDetector(),
69+
# Keep this at the end since parsing for this is a bit flaky
70+
OpenMPIPMIXEnvDetector(),
71+
],
5772
coordinator_address::Union{Nothing,String}=nothing,
5873
num_processes::Union{Nothing,Integer}=nothing,
5974
process_id::Union{Nothing,Integer}=nothing,
6075
local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
6176
initialization_timeout_in_seconds::Integer=300,
77+
single_gpu_per_process::Bool=true,
6278
)
6379
if all(
6480
Base.Fix2(!==, nothing),
@@ -91,7 +107,7 @@ function auto_detect_unset_distributed_params(;
91107
process_id = get_process_id(detector)
92108
end
93109

94-
if local_gpu_device_ids === nothing
110+
if local_gpu_device_ids === nothing && single_gpu_per_process
95111
local_gpu_device_ids = [get_local_process_id(detector)]
96112
end
97113

@@ -108,16 +124,18 @@ const _PMIX_SERVER_URI = (
108124
"PMIX_SERVER_URI41",
109125
"PMIX_SERVER_URI21",
110126
)
127+
const _PMIX_NAMESPACE = "PMIX_NAMESPACE"
128+
const _PRTERUN = "PRTE_LAUNCHED"
129+
const _PMIX_VERSION = "PMIX_VERSION"
111130
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
112131
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
113132
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"
114133

115134
is_env_present(::OpenMPIORTEEnvDetector) = haskey(ENV, _ORTE_URI)
116-
is_env_present(::OpenMPIPMIXEnvDetector) = any(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
135+
is_env_present(::OpenMPIPMIXEnvDetector) = haskey(ENV, _PMIX_NAMESPACE)
117136

118137
function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
119138
orte_uri = ENV[_ORTE_URI]
120-
121139
job_id = parse(Int, split(orte_uri, '.'; limit=2)[1])
122140
port = job_id % 2^12 + (65535 - 2^12 + 1)
123141

@@ -132,11 +150,48 @@ function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
132150
return "$(launcher_ip):$(port)"
133151
end
134152

153+
function _throw_pmix_env_error(msg)
154+
    msg = msg * " Open an issue on Reactant with the relevant PMIX Environment Variables \
155+
(you might want to obfuscate identifiable variables from this log \
156+
before opening an issue)\n\n"
157+
for (var, val) in [var => val for (var, val) in ENV if startswith(var, "PMIX")]
158+
msg *= " * $var => $val.\n"
159+
end
160+
return error(msg)
161+
end
162+
135163
function get_coordinator_address(::OpenMPIPMIXEnvDetector, ::Integer)
136-
varname = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
137-
pmix_uri = ENV[_PMIX_SERVER_URI[varname]]
164+
pmix_version = parse(VersionNumber, ENV[_PMIX_VERSION])
165+
pmix_uri = ENV[_PMIX_SERVER_URI[findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)]]
166+
@debug "PMIX VERSION: $(pmix_version)"
167+
    if v"5" ≤ pmix_version < v"6"
168+
return get_coordinator_address_pmixv5(pmix_uri)
169+
    elseif v"2" ≤ pmix_version < v"4"
170+
return get_coordinator_address_pmixv2_or_3(pmix_uri)
171+
else
172+
_throw_pmix_env_error("Unsupported PMIX version: $(pmix_version).")
173+
end
174+
end
175+
176+
function get_coordinator_address_pmixv2_or_3(pmix_uri)
177+
pre_semicolon = first(split(pmix_uri, ";"))
178+
if startswith(pre_semicolon, "pmix-server.")
179+
job_id = parse(Int, first(split(last(split(pre_semicolon, '.'; limit=2)))))
180+
elseif contains(pre_semicolon, ".")
181+
job_id = parse(Int, first(split(pre_semicolon, '.')))
182+
else
183+
_throw_pmix_env_error("Could not parse coordinator address from Open MPI \
184+
environment.")
185+
end
186+
return get_coordinator_address_from_pmix_uri(pmix_uri, job_id)
187+
end
138188

139-
job_id = parse(Int, split(split(pmix_uri, '-'; limit=3)[3], "@"; limit=2)[1])
189+
function get_coordinator_address_pmixv5(pmix_uri)
190+
job_id = parse(Int, first(split(last(split(pmix_uri, '-'; limit=3)), "@"; limit=2)))
191+
return get_coordinator_address_from_pmix_uri(pmix_uri, job_id)
192+
end
193+
194+
function get_coordinator_address_from_pmix_uri(pmix_uri, job_id)
140195
port = job_id % 2^12 + (65535 - 2^12 + 1)
141196

142197
launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)
@@ -159,4 +214,45 @@ function get_local_process_id(::AbstractOMPIClusterEnvDetector)
159214
return parse(Int, ENV[_OMPI_LOCAL_PROCESS_ID])
160215
end
161216

217+
# SlurmEnvDetector
218+
# Based on https://github.com/jax-ml/jax/blob/d89835acbacec938971400d6fa54ea6dd5efe76c/jax/_src/clusters/slurm_cluster.py#L3
219+
const _SLURM_JOB_ID = "SLURM_JOB_ID"
220+
const _SLURM_NODELIST = "SLURM_STEP_NODELIST"
221+
const _SLURM_PROCESS_COUNT = "SLURM_NTASKS"
222+
const _SLURM_PROCESS_ID = "SLURM_PROCID"
223+
const _SLURM_LOCAL_PROCESS_ID = "SLURM_LOCALID"
224+
const _SLURM_NUM_NODES = "SLURM_STEP_NUM_NODES"
225+
226+
is_env_present(::SlurmEnvDetector) = haskey(ENV, _SLURM_JOB_ID)
227+
228+
function get_coordinator_address(::SlurmEnvDetector, ::Integer)
229+
port = parse(Int, ENV[_SLURM_JOB_ID]) % 2^12 + (65535 - 2^12 + 1)
230+
231+
# Parse the first hostname of the job
232+
# If we are looking for 'node001',
233+
# node_list potential formats are 'node001', 'node001,host2',
234+
# 'node[001-0015],host2', and 'node[001,007-015],host2'.
235+
node_list = ENV[_SLURM_NODELIST]
236+
ind = findfirst(Base.Fix2(in, (',', '[')), node_list)
237+
ind = isnothing(ind) ? length(node_list) + 1 : ind
238+
239+
if ind == length(node_list) + 1 || node_list[ind] == ','
240+
# 'node001' or 'node001,host2'
241+
return "$(node_list[1:ind-1]):$(port)"
242+
else
243+
# 'node[001-0015],host2' or 'node[001,007-015],host2'
244+
prefix = node_list[1:(ind - 1)]
245+
suffix = node_list[(ind + 1):end]
246+
ind2 = findfirst(Base.Fix2(in, (',', '-')), suffix)
247+
ind2 = isnothing(ind2) ? length(suffix) : ind2
248+
return "$(prefix)$(suffix[1:ind2-1]):$(port)"
249+
end
250+
end
251+
252+
get_process_count(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_PROCESS_COUNT])
253+
254+
get_process_id(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_PROCESS_ID])
255+
256+
get_local_process_id(::SlurmEnvDetector) = parse(Int, ENV[_SLURM_LOCAL_PROCESS_ID])
257+
162258
end

src/xla/Distributed.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function update!(
129129
coordinator_address::String,
130130
num_processes::Int,
131131
process_id::Int,
132-
local_gpu_device_ids::Vector{Int},
132+
local_gpu_device_ids::Union{Nothing,Vector{Int}},
133133
coordinator_bind_address::Union{Nothing,String}=nothing,
134134
cluster_register_timeout_in_minutes::Integer=60,
135135
rpc_timeout_in_seconds::Integer=120,
@@ -141,7 +141,9 @@ function update!(
141141
    @assert 0 ≤ process_id < num_processes
142142

143143
state.coordinator_address = coordinator_address
144-
state.local_gpu_device_ids = local_gpu_device_ids
144+
if local_gpu_device_ids !== nothing
145+
state.local_gpu_device_ids = local_gpu_device_ids
146+
end
145147
state.process_id = process_id
146148
state.num_processes = num_processes
147149

src/xla/IFRT/Array.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ function XLA.buffer_on_cpu(::Array)
138138
end
139139

140140
function XLA.to_host(buffer::Array, data, reactant_sharding)
141-
if length(XLA.devices(XLA.sharding(buffer))) == 1
141+
reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding)
142+
143+
if reactant_sharding isa Reactant.Sharding.NoSharding
142144
GC.@preserve buffer data begin
143145
@ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
144146
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
@@ -147,7 +149,6 @@ function XLA.to_host(buffer::Array, data, reactant_sharding)
147149
return data
148150
end
149151

150-
reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding)
151152
@assert reactant_sharding isa Reactant.Sharding.HloSharding
152153
client = XLA.client(buffer)
153154
all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids)

0 commit comments

Comments
 (0)