121121Components APIs
122122-----------------
123123"""
124+ import shlex
124125from pathlib import Path
125- from typing import Dict , Optional
126+ from typing import Dict , Optional , Iterable
126127
127128import torchx
128129import torchx .specs as specs
131132
132133def ddp (
133134 * script_args : str ,
134- script : str ,
135+ script : Optional [str ] = None ,
136+ m : Optional [str ] = None ,
135137 image : str = torchx .IMAGE ,
136138 name : Optional [str ] = None ,
139+ h : Optional [str ] = None ,
137140 cpu : int = 2 ,
138141 gpu : int = 0 ,
139142 memMB : int = 1024 ,
140- h : Optional [str ] = None ,
141143 j : str = "1x2" ,
142144 env : Optional [Dict [str , str ]] = None ,
143- rdzv_endpoint : str = "etcd-server.default.svc.cluster.local:2379" ,
145+ max_restarts : Optional [int ] = None ,
146+ rdzv_backend : str = "c10d" ,
147+ rdzv_endpoint : Optional [str ] = None ,
144148) -> specs .AppDef :
145149 """
146150 Distributed data parallel style application (one role, multi-replica).
@@ -154,6 +158,7 @@ def ddp(
154158 Args:
155159 script_args: arguments to the main module
156160 script: script or binary to run within the image
161+ m: the python module path to run
157162 image: image (e.g. docker)
158163 name: job name override (uses the script name if not specified)
159164 cpu: number of cpus per replica
@@ -162,9 +167,14 @@ def ddp(
162167 h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
163168 j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
164169 env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
165- rdzv_endpoint: etcd server endpoint (only matters when nnodes > 1)
170+ max_restarts: the number of restarts allowed
171+ rdzv_backend: rendezvous backend (only matters when nnodes > 1)
172+ rdzv_endpoint: rendezvous server endpoint (only matters when nnodes > 1), defaults to rank0 host for schedulers that support it
166173 """
167174
175+ if (script is None ) == (m is None ):
176+ raise ValueError ("exactly one of --script and -m must be specified" )
177+
168178 rep = j .split ("x" )
169179 if len (rep ) == 1 : # num replicas only
170180 nnodes = 1
@@ -175,33 +185,79 @@ def ddp(
175185 else :
176186 raise ValueError (f"Invalid format for -j, usage example: 1x4. Given: { j } " )
177187
178- script_name_noext = Path (script ).stem # script name no extension
188+ if script :
189+ # script name/module no extension
190+ role_name = Path (script ).stem
191+ elif m :
192+ role_name = m .rpartition ("." )[2 ]
193+ else :
194+ raise ValueError ("failed to compute role_name" )
195+
196+ if rdzv_endpoint is None :
197+ rdzv_endpoint = _noquote (f"$${ macros .rank0_env } :29500" )
198+
199+ if nnodes == 1 :
200+ rdzv_backend = "c10d"
201+ rdzv_endpoint = "localhost:29500"
202+
203+ if env is None :
204+ env = {}
205+ env .setdefault ("LOGLEVEL" , "INFO" )
206+
207+ cmd = [
208+ "python" ,
209+ "-m" ,
210+ "torch.distributed.run" ,
211+ "--rdzv_backend" ,
212+ rdzv_backend ,
213+ "--rdzv_endpoint" ,
214+ rdzv_endpoint ,
215+ "--rdzv_id" ,
216+ f"{ macros .app_id } " ,
217+ "--nnodes" ,
218+ str (nnodes ),
219+ "--nproc_per_node" ,
220+ str (nproc_per_node ),
221+ ]
222+ if max_restarts is not None :
223+ cmd += ["--max_restarts" , str (max_restarts )]
224+ if script is not None :
225+ cmd += [script ]
226+ elif m is not None :
227+ cmd += ["-m" , m ]
228+ cmd += script_args
179229 return specs .AppDef (
180- name = name or script_name_noext ,
230+ name = name or role_name ,
181231 roles = [
182232 specs .Role (
183- name = script_name_noext ,
233+ name = role_name ,
184234 image = image ,
185- entrypoint = "python " ,
235+ entrypoint = "bash " ,
186236 num_replicas = nnodes ,
187237 resource = specs .resource (cpu = cpu , gpu = gpu , memMB = memMB , h = h ),
188- args = [
189- "-m" ,
190- "torch.distributed.run" ,
191- "--rdzv_backend" ,
192- ("c10d" if nnodes == 1 else "etcd" ),
193- "--rdzv_endpoint" ,
194- ("localhost:29500" if nnodes == 1 else rdzv_endpoint ),
195- "--rdzv_id" ,
196- f"{ macros .app_id } " ,
197- "--nnodes" ,
198- str (nnodes ),
199- "--nproc_per_node" ,
200- str (nproc_per_node ),
201- script ,
202- * script_args ,
203- ],
204- env = env or {},
238+ args = ["-c" , _args_join (cmd )],
239+ env = env ,
240+ port_map = {
241+ "c10d" : 29500 ,
242+ },
205243 )
206244 ],
207245 )
246+
247+
248+ def _args_join (args : Iterable [str ]) -> str :
249+ """
250+ _args_join is like shlex.join but if the argument is wrapped in _noquote
251+ it'll not quote that argument.
252+ """
253+ quoted = [arg if isinstance (arg , _noquote ) else shlex .quote (arg ) for arg in args ]
254+ return " " .join (quoted )
255+
256+
257+ class _noquote (str ):
258+ """
259+ _noquote is a wrapper around str that indicates that the argument shouldn't
260+ be passed through shlex.quote.
261+ """
262+
263+ pass
0 commit comments