# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+ import logging
import os
import pickle
+ import shlex
+ import signal
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
- from typing import Dict, Literal, NamedTuple, Optional, Sequence
+ from typing import cast, Dict, List, Literal, NamedTuple, Optional, Sequence
+
+ from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
+ from monarch._rust_bindings.monarch_hyperactor.config import configure
+
+ from monarch._src.actor.bootstrap import attach_to_workers

# note: the jobs api is intended as a library so it should
# only be importing _public_ monarch API functions.
from monarch._src.actor.host_mesh import HostMesh, this_host
+
from typing_extensions import Self


@@ -39,6 +48,12 @@ class CachedRunning(NamedTuple):
    job: "JobTrait"


+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+ logger.addHandler(logging.StreamHandler(sys.stderr))
+ logger.propagate = False
+
+
class JobTrait(ABC):
    def __init__(self):
        super().__init__()
@@ -102,6 +117,10 @@ def apply(self, client_script: Optional[str] = None):
        self._create(client_script)
        self._status = "running"

+     @property
+     def active(self) -> bool:
+         return self._running is not None
+
    def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobState:
        """
        Get the current state of this job, containing the host mesh objects of its requires that were requested
@@ -124,30 +143,44 @@ def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobState:
        # calls to attach_to_workers and return the HostMeshes
        running_job = self._running
        if running_job is not None:
+             logger.info("Job is running, returning current state")
            return running_job._state()

        cached = self._load_cached(cached_path)
        if cached is not None:
            self._status = CachedRunning(cached)
+             logger.info("Connecting to cached job")
            return cached._state()
+         logger.info("Applying current job")
        self.apply()
        if cached_path is not None:
            # Create the directory for cached_path if it doesn't exist
            cache_dir = os.path.dirname(cached_path)
            if cache_dir:  # Only create if there's a directory component
                os.makedirs(cache_dir, exist_ok=True)
+             logger.info("Saving job to cache at %s", cached_path)
            self.dump(cached_path)
+         logger.info("Job has started, connecting to current state")
        return self._state()

    def _load_cached(self, cached_path: Optional[str]) -> "Optional[JobTrait]":
        if cached_path is None:
+             logger.info("No cached path provided")
            return None
        try:
            job = job_load(cached_path)
+             logger.info("Found cached job at path: %s", cached_path)
        except FileNotFoundError:
+             logger.info("No cached job found at path: %s", cached_path)
            return None
        running = job._running
-         if running is None or not running.can_run(self):
+         if running is None:
+             logger.info("Cached job is not running")
+             return None
+         if not running.can_run(self):
+             logger.info("Cached job cannot run this spec, removing cache")
+             running._kill()
+             os.remove(cached_path)
            return None
        return job
@@ -164,6 +197,12 @@ def dumps(self) -> bytes:
        # @lint-ignore PYTHONPICKLEISBAD
        return pickle.dumps(self)

+     def kill(self):
+         running = self._running
+         if running is not None:
+             running._kill()
+         self._status = "not_running"
+
    @abstractmethod
    def _state(self) -> JobState: ...

@@ -181,11 +220,6 @@ def can_run(self, spec: "JobTrait") -> bool:

        ...

-     def kill(self):
-         running = self._running
-         if running is not None:
-             running._kill()
-
    @abstractmethod
    def _kill(self):
        """
@@ -244,8 +278,10 @@ def _create(self, client_script: Optional[str]):
        log_dir = self._setup_log_directory()
        self._run_client_as_daemon(client_script, log_dir)

-         print(f"Started client script {client_script} with PID: {self.process.pid}")
-         print(f"Logs available at: {log_dir}")
+         logger.info(
+             "Started client script %s with PID: %d", client_script, self.process.pid
+         )
+         logger.info("Logs available at: %s", log_dir)

    def _setup_log_directory(self) -> str:
        """Create a log directory for the batch job."""
@@ -323,5 +359,150 @@ def _create(self, client_script: Optional[str] = None):
        return self._job._create(client_script)

    def _kill(self):
-         print("Stopping Batch Job")
+         logger.info("Stopping Batch Job")
        return self._job._kill()
+
+
+ class ProcessState(NamedTuple):
+     pid: int
+     channel: str
+
+
+ class LoginJob(JobTrait):
372+ """
373+ Makes a connections directly to hosts via an explicit list.
374+ """
+
+     def __init__(self):
+         super().__init__()
+         self._meshes: Dict[str, List[str]] = {}
+         self._host_to_pid: Dict[str, ProcessState] = {}
+
+     def add_mesh(self, name: str, hosts: List[str]):
+         self._meshes[name] = hosts
+
+     def _state(self) -> JobState:
+         if not self._pids_active():
+             raise RuntimeError("lost connection")
+         hosts = {
+             name: cast(
+                 "HostMesh",
+                 attach_to_workers(
+                     name=name,
+                     ca="trust_all_connections",
+                     workers=[self._host_to_pid[v].channel for v in values],
+                 ),
+             )
+             for name, values in self._meshes.items()
+         }
+         return JobState(hosts)
+
+     def _create(self, client_script: Optional[str]):
+         if client_script is not None:
+             raise RuntimeError("LoginJob cannot run batch-mode scripts")
+
+         for hosts in self._meshes.values():
+             for host in hosts:
+                 self._host_to_pid[host] = self._start_host(host)
+
+     @abstractmethod
+     def _start_host(self, host: str) -> ProcessState: ...
+
+     def can_run(self, spec: "JobTrait") -> bool:
412+ """
413+ Is this job capable of running the job spec? This is used to check if a
414+ cached job can be used to run `spec` instead of creating a new reserveration.
415+
416+ It is also used by the batch run infrastructure to indicate that the batch job can certainly run itself.
417+ """
+         return (
+             isinstance(spec, LoginJob)
+             and spec._meshes == self._meshes
+             and self._pids_active()
+         )
+
+     def _pids_active(self) -> bool:
+         if not self.active:
+             return False
+         for _, p in self._host_to_pid.items():
+             try:
+                 # Check if process exists by sending signal 0
+                 os.kill(p.pid, 0)
+             except OSError:
+                 # Process doesn't exist or we don't have permission to signal it
+                 return False
+         return True
+
+     def _kill(self):
+         for p in self._host_to_pid.values():
+             try:
+                 os.kill(p.pid, signal.SIGKILL)
+             except OSError:
+                 pass
+
+
+ class FakeLocalLoginJob(LoginJob):
445+ """
446+
447+ Fake it that we are logging in by just making a local process that runs the bootstrap.
448+ """
+
+     def __init__(self):
+         super().__init__()
+         configure(default_transport=ChannelTransport.Tcp)
+
+         self._next_port = 12345
+
+     def _start_host(self, host: str) -> ProcessState:
+         port = self._next_port
+         self._next_port += 1
+
+         env = {**os.environ}
+         if "FB_XAR_INVOKED_NAME" in os.environ:
+             env["PYTHONPATH"] = ":".join(sys.path)
+         addr = f"tcp://[::1]:{port}"
+         bind_addr = f"tcp://[::1]:{port}"
+         proc = subprocess.Popen(
+             [
+                 sys.executable,
+                 "-c",
+                 f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(bind_addr)}, ca="trust_all_connections")',
+             ],
+             env=env,
+             start_new_session=True,
+         )
+         return ProcessState(proc.pid, addr)
+
+
+ class SSHJob(LoginJob):
+     def __init__(
+         self,
+         python_exe: str = "python",
+         ssh_args: Sequence[str] = (),
+         monarch_port: int = 22222,
+     ):
+         configure(default_transport=ChannelTransport.Tcp)
+         self._python_exe = python_exe
+         self._ssh_args = ssh_args
+         self._port = monarch_port
+         super().__init__()
+
+     def _start_host(self, host: str) -> ProcessState:
+         addr = f"tcp://{host}:{self._port}"
+         startup = f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")'
+
+         command = f"{shlex.quote(self._python_exe)} -c {shlex.quote(startup)}"
+         proc = subprocess.Popen(
+             ["ssh", *self._ssh_args, host, "-n", command],
+             start_new_session=True,
+         )
+         return ProcessState(proc.pid, addr)
+
+     def can_run(self, spec):
+         return (
+             isinstance(spec, SSHJob)
+             and spec._python_exe == self._python_exe
+             and self._port == spec._port
+             and self._ssh_args == spec._ssh_args
+             and super().can_run(spec)
+         )