@@ -8,10 +8,16 @@ function initialize(;
8
8
coordinator_address:: Union{Nothing,String} = nothing ,
9
9
num_processes:: Union{Nothing,Integer} = nothing ,
10
10
process_id:: Union{Nothing,Integer} = nothing ,
11
+ single_gpu_per_process:: Bool = true ,
11
12
local_gpu_device_ids:: Union{Nothing,Vector{Int}} = nothing ,
12
13
initialization_timeout_in_seconds:: Integer = 300 ,
13
14
kwargs... ,
14
15
)
16
+ if isinteractive ()
17
+ @warn " Reactant.Distributed.initialize() should not be called in interactive mode. \
18
+ Use Reactant.Distributed.initialize() in a script instead."
19
+ end
20
+
15
21
@assert ! initialized[] " `Distributed.initialize` has already been called"
16
22
17
23
(coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params (;
@@ -20,6 +26,7 @@ function initialize(;
20
26
process_id,
21
27
local_gpu_device_ids,
22
28
initialization_timeout_in_seconds,
29
+ single_gpu_per_process,
23
30
)
24
31
25
32
@debug " Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids
@@ -43,6 +50,8 @@ struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end
43
50
44
51
struct MPIEnvDetector <: AbstractClusterEnvDetector end
45
52
53
+ struct SlurmEnvDetector <: AbstractClusterEnvDetector end
54
+
46
55
# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py
47
56
48
57
is_env_present (:: AbstractClusterEnvDetector ) = false
@@ -53,12 +62,19 @@ function get_process_id end
53
62
function get_local_process_id end
54
63
55
64
function auto_detect_unset_distributed_params (;
56
- detector_list= [OpenMPIORTEEnvDetector (), OpenMPIPMIXEnvDetector (), MPIEnvDetector ()],
65
+ detector_list= [
66
+ SlurmEnvDetector (),
67
+ OpenMPIORTEEnvDetector (),
68
+ MPIEnvDetector (),
69
+ # Keep this at the end since parsing for this is a bit flaky
70
+ OpenMPIPMIXEnvDetector (),
71
+ ],
57
72
coordinator_address:: Union{Nothing,String} = nothing ,
58
73
num_processes:: Union{Nothing,Integer} = nothing ,
59
74
process_id:: Union{Nothing,Integer} = nothing ,
60
75
local_gpu_device_ids:: Union{Nothing,Vector{Int}} = nothing ,
61
76
initialization_timeout_in_seconds:: Integer = 300 ,
77
+ single_gpu_per_process:: Bool = true ,
62
78
)
63
79
if all (
64
80
Base. Fix2 (!= = , nothing ),
@@ -91,7 +107,7 @@ function auto_detect_unset_distributed_params(;
91
107
process_id = get_process_id (detector)
92
108
end
93
109
94
- if local_gpu_device_ids === nothing
110
+ if local_gpu_device_ids === nothing && single_gpu_per_process
95
111
local_gpu_device_ids = [get_local_process_id (detector)]
96
112
end
97
113
@@ -108,16 +124,18 @@ const _PMIX_SERVER_URI = (
108
124
" PMIX_SERVER_URI41" ,
109
125
" PMIX_SERVER_URI21" ,
110
126
)
127
+ const _PMIX_NAMESPACE = " PMIX_NAMESPACE"
128
+ const _PRTERUN = " PRTE_LAUNCHED"
129
+ const _PMIX_VERSION = " PMIX_VERSION"
111
130
const _OMPI_PROCESS_COUNT = " OMPI_COMM_WORLD_SIZE"
112
131
const _OMPI_PROCESS_ID = " OMPI_COMM_WORLD_RANK"
113
132
const _OMPI_LOCAL_PROCESS_ID = " OMPI_COMM_WORLD_LOCAL_RANK"
114
133
115
134
is_env_present (:: OpenMPIORTEEnvDetector ) = haskey (ENV , _ORTE_URI)
116
- is_env_present (:: OpenMPIPMIXEnvDetector ) = any (Base . Fix1 ( haskey, ENV ), _PMIX_SERVER_URI )
135
+ is_env_present (:: OpenMPIPMIXEnvDetector ) = haskey ( ENV , _PMIX_NAMESPACE )
117
136
118
137
function get_coordinator_address (:: OpenMPIORTEEnvDetector , :: Integer )
119
138
orte_uri = ENV [_ORTE_URI]
120
-
121
139
job_id = parse (Int, split (orte_uri, ' .' ; limit= 2 )[1 ])
122
140
port = job_id % 2 ^ 12 + (65535 - 2 ^ 12 + 1 )
123
141
@@ -132,11 +150,48 @@ function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
132
150
return " $(launcher_ip) :$(port) "
133
151
end
134
152
153
+ function _throw_pmix_env_error (msg)
154
+ msg = msg * " Open an issue on Reactant with the relevant PMIX Enviroment Variables \
155
+ (you might want to obfuscate identifiable variables from this log \
156
+ before opening an issue)\n\n "
157
+ for (var, val) in [var => val for (var, val) in ENV if startswith (var, " PMIX" )]
158
+ msg *= " * $var => $val .\n "
159
+ end
160
+ return error (msg)
161
+ end
162
+
135
163
function get_coordinator_address (:: OpenMPIPMIXEnvDetector , :: Integer )
136
- varname = findfirst (Base. Fix1 (haskey, ENV ), _PMIX_SERVER_URI)
137
- pmix_uri = ENV [_PMIX_SERVER_URI[varname]]
164
+ pmix_version = parse (VersionNumber, ENV [_PMIX_VERSION])
165
+ pmix_uri = ENV [_PMIX_SERVER_URI[findfirst (Base. Fix1 (haskey, ENV ), _PMIX_SERVER_URI)]]
166
+ @debug " PMIX VERSION: $(pmix_version) "
167
+ if v " 5" ≤ pmix_version < v " 6"
168
+ return get_coordinator_address_pmixv5 (pmix_uri)
169
+ elseif v " 2" ≤ pmix_version < v " 4"
170
+ return get_coordinator_address_pmixv2_or_3 (pmix_uri)
171
+ else
172
+ _throw_pmix_env_error (" Unsupported PMIX version: $(pmix_version) ." )
173
+ end
174
+ end
175
+
176
+ function get_coordinator_address_pmixv2_or_3 (pmix_uri)
177
+ pre_semicolon = first (split (pmix_uri, " ;" ))
178
+ if startswith (pre_semicolon, " pmix-server." )
179
+ job_id = parse (Int, first (split (last (split (pre_semicolon, ' .' ; limit= 2 )))))
180
+ elseif contains (pre_semicolon, " ." )
181
+ job_id = parse (Int, first (split (pre_semicolon, ' .' )))
182
+ else
183
+ _throw_pmix_env_error (" Could not parse coordinator address from Open MPI \
184
+ environment." )
185
+ end
186
+ return get_coordinator_address_from_pmix_uri (pmix_uri, job_id)
187
+ end
138
188
139
- job_id = parse (Int, split (split (pmix_uri, ' -' ; limit= 3 )[3 ], " @" ; limit= 2 )[1 ])
189
+ function get_coordinator_address_pmixv5 (pmix_uri)
190
+ job_id = parse (Int, first (split (last (split (pmix_uri, ' -' ; limit= 3 )), " @" ; limit= 2 )))
191
+ return get_coordinator_address_from_pmix_uri (pmix_uri, job_id)
192
+ end
193
+
194
+ function get_coordinator_address_from_pmix_uri (pmix_uri, job_id)
140
195
port = job_id % 2 ^ 12 + (65535 - 2 ^ 12 + 1 )
141
196
142
197
launcher_ip_match = match (r" tcp4://(.+?):|tcp6://\[ (.+?)\] " , pmix_uri)
@@ -159,4 +214,45 @@ function get_local_process_id(::AbstractOMPIClusterEnvDetector)
159
214
return parse (Int, ENV [_OMPI_LOCAL_PROCESS_ID])
160
215
end
161
216
217
+ # SlurmEnvDetector
218
+ # Based on https://github.com/jax-ml/jax/blob/d89835acbacec938971400d6fa54ea6dd5efe76c/jax/_src/clusters/slurm_cluster.py#L3
219
+ const _SLURM_JOB_ID = " SLURM_JOB_ID"
220
+ const _SLURM_NODELIST = " SLURM_STEP_NODELIST"
221
+ const _SLURM_PROCESS_COUNT = " SLURM_NTASKS"
222
+ const _SLURM_PROCESS_ID = " SLURM_PROCID"
223
+ const _SLURM_LOCAL_PROCESS_ID = " SLURM_LOCALID"
224
+ const _SLURM_NUM_NODES = " SLURM_STEP_NUM_NODES"
225
+
226
+ is_env_present (:: SlurmEnvDetector ) = haskey (ENV , _SLURM_JOB_ID)
227
+
228
+ function get_coordinator_address (:: SlurmEnvDetector , :: Integer )
229
+ port = parse (Int, ENV [_SLURM_JOB_ID]) % 2 ^ 12 + (65535 - 2 ^ 12 + 1 )
230
+
231
+ # Parse the first hostname of the job
232
+ # If we are looking for 'node001',
233
+ # node_list potential formats are 'node001', 'node001,host2',
234
+ # 'node[001-0015],host2', and 'node[001,007-015],host2'.
235
+ node_list = ENV [_SLURM_NODELIST]
236
+ ind = findfirst (Base. Fix2 (in, (' ,' , ' [' )), node_list)
237
+ ind = isnothing (ind) ? length (node_list) + 1 : ind
238
+
239
+ if ind == length (node_list) + 1 || node_list[ind] == ' ,'
240
+ # 'node001' or 'node001,host2'
241
+ return " $(node_list[1 : ind- 1 ]) :$(port) "
242
+ else
243
+ # 'node[001-0015],host2' or 'node[001,007-015],host2'
244
+ prefix = node_list[1 : (ind - 1 )]
245
+ suffix = node_list[(ind + 1 ): end ]
246
+ ind2 = findfirst (Base. Fix2 (in, (' ,' , ' -' )), suffix)
247
+ ind2 = isnothing (ind2) ? length (suffix) : ind2
248
+ return " $(prefix)$(suffix[1 : ind2- 1 ]) :$(port) "
249
+ end
250
+ end
251
+
252
+ get_process_count (:: SlurmEnvDetector ) = parse (Int, ENV [_SLURM_PROCESS_COUNT])
253
+
254
+ get_process_id (:: SlurmEnvDetector ) = parse (Int, ENV [_SLURM_PROCESS_ID])
255
+
256
+ get_local_process_id (:: SlurmEnvDetector ) = parse (Int, ENV [_SLURM_LOCAL_PROCESS_ID])
257
+
162
258
end
0 commit comments