121121Components APIs
122122-----------------
123123"""
124+ import shlex
124125from pathlib import Path
125- from typing import Dict , Optional
126+ from typing import Dict , Optional , Iterable
126127
127128import torchx
128129import torchx .specs as specs
@@ -140,7 +141,8 @@ def ddp(
140141 h : Optional [str ] = None ,
141142 j : str = "1x2" ,
142143 env : Optional [Dict [str , str ]] = None ,
143- rdzv_endpoint : str = "etcd-server.default.svc.cluster.local:2379" ,
144+ rdzv_backend : str = "c10d" ,
145+ rdzv_endpoint : Optional [str ] = None ,
144146) -> specs .AppDef :
145147 """
146148 Distributed data parallel style application (one role, multi-replica).
@@ -162,7 +164,8 @@ def ddp(
162164 h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
163165 j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
164166 env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
165- rdzv_endpoint: etcd server endpoint (only matters when nnodes > 1)
167+ rdzv_backend: rendezvous backend (only matters when nnodes > 1)
168+ rdzv_endpoint: rendezvous server endpoint (only matters when nnodes > 1), defaults to rank0 host for schedulers that support it
166169 """
167170
168171 rep = j .split ("x" )
@@ -176,32 +179,60 @@ def ddp(
176179 raise ValueError (f"Invalid format for -j, usage example: 1x4. Given: { j } " )
177180
178181 script_name_noext = Path (script ).stem # script name no extension
182+
183+ if rdzv_endpoint is None :
184+ rdzv_endpoint = _noquote (f"$${ macros .rank0_env } :29500" )
185+
186+ if nnodes == 1 :
187+ rdzv_backend = "c10d"
188+ rdzv_endpoint = "localhost:29500"
189+
190+ cmd = [
191+ "python" ,
192+ "-m" ,
193+ "torch.distributed.run" ,
194+ "--rdzv_backend" ,
195+ rdzv_backend ,
196+ "--rdzv_endpoint" ,
197+ rdzv_endpoint ,
198+ "--rdzv_id" ,
199+ f"{ macros .app_id } " ,
200+ "--nnodes" ,
201+ str (nnodes ),
202+ "--nproc_per_node" ,
203+ str (nproc_per_node ),
204+ script ,
205+ * script_args ,
206+ ]
179207 return specs .AppDef (
180208 name = name or script_name_noext ,
181209 roles = [
182210 specs .Role (
183211 name = script_name_noext ,
184212 image = image ,
185- entrypoint = "python " ,
213+ entrypoint = "bash " ,
186214 num_replicas = nnodes ,
187215 resource = specs .resource (cpu = cpu , gpu = gpu , memMB = memMB , h = h ),
188- args = [
189- "-m" ,
190- "torch.distributed.run" ,
191- "--rdzv_backend" ,
192- ("c10d" if nnodes == 1 else "etcd" ),
193- "--rdzv_endpoint" ,
194- ("localhost:29500" if nnodes == 1 else rdzv_endpoint ),
195- "--rdzv_id" ,
196- f"{ macros .app_id } " ,
197- "--nnodes" ,
198- str (nnodes ),
199- "--nproc_per_node" ,
200- str (nproc_per_node ),
201- script ,
202- * script_args ,
203- ],
216+ args = ["-c" , _args_join (cmd )],
204217 env = env or {},
205218 )
206219 ],
207220 )
221+
222+
def _args_join(args: Iterable[str]) -> str:
    """
    Join ``args`` into a single space-separated shell command string.

    Works like ``shlex.join`` with one exception: an argument wrapped in
    ``_noquote`` is emitted verbatim instead of being passed through
    ``shlex.quote``. This is for arguments that must reach the shell
    unquoted, e.g. an environment variable reference the shell should expand.

    Args:
        args: the command-line arguments, in order

    Returns:
        The arguments joined by single spaces; an empty string when ``args``
        is empty.
    """
    # Quoting is decided per argument: _noquote instances are trusted as-is,
    # everything else is escaped so it survives as exactly one shell word.
    return " ".join(
        arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args
    )
230+
231+
class _noquote(str):
    """
    Marker subclass of ``str`` for arguments that must not be shell-quoted.

    It adds no behavior of its own: a ``_noquote`` value compares, hashes and
    formats exactly like the plain string it wraps. The type only indicates
    that the argument shouldn't be passed through ``shlex.quote``.
    """
0 commit comments