From d8bf18881b253b5c8fb8e08ee6b17f590dd9db88 Mon Sep 17 00:00:00 2001
From: zdevito
Date: Tue, 7 Oct 2025 11:36:34 -0700
Subject: [PATCH 1/2] SSHJob/LoginJob

Add a simple SSHJob variant that lets you establish a host mesh by
ssh-ing directly into machines.

This is probably too simple to use in practice, but it demonstrates what is
necessary to get a monarch job running.

Differential Revision: [D84016804](https://our.internmc.facebook.com/intern/diff/D84016804/)

[ghstack-poisoned]
---
 python/monarch/_src/actor/bootstrap.py |   6 +
 python/monarch/_src/job/job.py         | 201 +++++++++++++++++++++++--
 python/tests/test_python_actors.py     |  42 ++++++
 3 files changed, 239 insertions(+), 10 deletions(-)

diff --git a/python/monarch/_src/actor/bootstrap.py b/python/monarch/_src/actor/bootstrap.py
index f4e66bf99..60a209657 100644
--- a/python/monarch/_src/actor/bootstrap.py
+++ b/python/monarch/_src/actor/bootstrap.py
@@ -73,6 +73,11 @@ def run_worker_loop_forever(
         raise NotImplementedError("TLS security plumbing")
     # we maybe want to actually return the future and let you do other stuff,
     # not sure ...
+    if "tcp://*" in address:
+        raise NotImplementedError(
+            "the implementation does not get the host name right if it was specified as a wildcard; this still has to be fixed"
+        )
+
     _run_worker_loop_forever(address).block_on()
@@ -104,6 +109,7 @@ def attach_to_workers(
     if private_key is not None or ca != "trust_all_connections":
         raise NotImplementedError("TLS security plumbing")
+
     workers_tasks = [_as_python_task(w) for w in workers]
     host_mesh: PythonTask[HyHostMesh] = _attach_to_workers(workers_tasks, name=name)
     extent = Extent(["hosts"], [len(workers)])
diff --git a/python/monarch/_src/job/job.py b/python/monarch/_src/job/job.py
index 842574d18..ebc249975 100644
--- a/python/monarch/_src/job/job.py
+++ b/python/monarch/_src/job/job.py
@@ -4,17 +4,26 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 import os
 import pickle
+import shlex
+import signal
 import subprocess
 import sys
 import tempfile
 from abc import ABC, abstractmethod
-from typing import Dict, Literal, NamedTuple, Optional, Sequence
+from typing import cast, Dict, List, Literal, NamedTuple, Optional, Sequence
+
+from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
+from monarch._rust_bindings.monarch_hyperactor.config import configure
+
+from monarch._src.actor.bootstrap import attach_to_workers

 # note: the jobs api is intended as a library so it should
 # only be importing _public_ monarch API functions.
 from monarch._src.actor.host_mesh import HostMesh, this_host
+
 from typing_extensions import Self
@@ -39,6 +48,12 @@ class CachedRunning(NamedTuple):
     job: "JobTrait"


+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(logging.StreamHandler(sys.stderr))
+logger.propagate = False
+
+
 class JobTrait(ABC):
     def __init__(self):
         super().__init__()
@@ -102,6 +117,10 @@ def apply(self, client_script: Optional[str] = None):
         self._create(client_script)
         self._status = "running"

+    @property
+    def active(self) -> bool:
+        return self._running is not None
+
     def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobState:
         """
         Get the current state of this job, containing the host mesh objects of its requires that were requested
@@ -124,30 +143,44 @@ def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobSta
         # calls to attach_to_workers and return the HostMeshes
         running_job = self._running
         if running_job is not None:
+            logger.info("Job is running, returning current state")
             return running_job._state()

         cached = self._load_cached(cached_path)
         if cached is not None:
             self._status = CachedRunning(cached)
+            logger.info("Connecting to cached job")
             return cached._state()

+        logger.info("Applying current job")
         self.apply()
         if cached_path is not None:
             # Create the directory for cached_path if it doesn't exist
             cache_dir = os.path.dirname(cached_path)
             if cache_dir:  # Only create if there's a directory component
                 os.makedirs(cache_dir, exist_ok=True)
+            logger.info("Saving job to cache at %s", cached_path)
             self.dump(cached_path)
+        logger.info("Job has started, connecting to current state")
         return self._state()

     def _load_cached(self, cached_path: Optional[str]) -> "Optional[JobTrait]":
         if cached_path is None:
+            logger.info("No cached path provided")
             return None
         try:
             job = job_load(cached_path)
+            logger.info("Found cached job at path: %s", cached_path)
         except FileNotFoundError:
+            logger.info("No cached job found at path: %s", cached_path)
             return None
         running = job._running
-        if running is None or not running.can_run(self):
+        if running is None:
+            logger.info("Cached job is not running")
+            return None
+        if not running.can_run(self):
+            logger.info("Cached job cannot run this spec, removing cache")
+            running._kill()
+            os.remove(cached_path)
             return None
         return job
@@ -164,6 +197,12 @@ def dumps(self) -> bytes:
         # @lint-ignore PYTHONPICKLEISBAD
         return pickle.dumps(self)

+    def kill(self):
+        running = self._running
+        if running is not None:
+            running._kill()
+        self._status = "not_running"
+
     @abstractmethod
     def _state(self) -> JobState: ...

@@ -181,11 +220,6 @@ def can_run(self, spec: "JobTrait") -> bool: ...
-    def kill(self):
-        running = self._running
-        if running is not None:
-            running._kill()
-
     @abstractmethod
     def _kill(self):
         """
@@ -244,8 +278,10 @@ def _create(self, client_script: Optional[str]):
         log_dir = self._setup_log_directory()
         self._run_client_as_daemon(client_script, log_dir)
-        print(f"Started client script {client_script} with PID: {self.process.pid}")
-        print(f"Logs available at: {log_dir}")
+        logger.info(
+            "Started client script %s with PID: %d", client_script, self.process.pid
+        )
+        logger.info("Logs available at: %s", log_dir)

     def _setup_log_directory(self) -> str:
         """Create a log directory for the batch job."""
@@ -323,5 +359,150 @@ def _create(self, client_script: Optional[str] = None):
         return self._job._create(client_script)

     def _kill(self):
-        print("Stopping Batch Job")
+        logger.info("Stopping Batch Job")
         return self._job._kill()
+
+
+class ProcessState(NamedTuple):
+    pid: int
+    channel: str
+
+
+class LoginJob(JobTrait):
+    """
+    Makes connections directly to hosts via an explicit list.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._meshes: Dict[str, List[str]] = {}
+        self._host_to_pid: Dict[str, ProcessState] = {}
+
+    def add_mesh(self, name: str, hosts: List[str]):
+        self._meshes[name] = hosts
+
+    def _state(self) -> JobState:
+        if not self._pids_active():
+            raise RuntimeError("lost connection")
+        hosts = {
+            name: cast(
+                "HostMesh",
+                attach_to_workers(
+                    name=name,
+                    ca="trust_all_connections",
+                    workers=[self._host_to_pid[v].channel for v in values],
+                ),
+            )
+            for name, values in self._meshes.items()
+        }
+        return JobState(hosts)
+
+    def _create(self, client_script: Optional[str]):
+        if client_script is not None:
+            raise RuntimeError("LoginJob cannot run batch-mode scripts")
+
+        for hosts in self._meshes.values():
+            for host in hosts:
+                self._host_to_pid[host] = self._start_host(host)
+
+    @abstractmethod
+    def _start_host(self, host: str) -> ProcessState: ...
+
+    def can_run(self, spec: "JobTrait") -> bool:
+        """
+        Is this job capable of running the job spec? This is used to check if a
+        cached job can be used to run `spec` instead of creating a new reservation.
+
+        It is also used by the batch run infrastructure to indicate that the batch job can certainly run itself.
+        """
+        return (
+            isinstance(spec, LoginJob)
+            and spec._meshes == self._meshes
+            and self._pids_active()
+        )
+
+    def _pids_active(self) -> bool:
+        if not self.active:
+            return False
+        for _, p in self._host_to_pid.items():
+            try:
+                # Check if process exists by sending signal 0
+                os.kill(p.pid, 0)
+            except OSError:
+                # Process doesn't exist or we don't have permission to signal it
+                return False
+        return True
+
+    def _kill(self):
+        for p in self._host_to_pid.values():
+            try:
+                os.kill(p.pid, signal.SIGKILL)
+            except OSError:
+                pass
+
+
+class FakeLocalLoginJob(LoginJob):
+    """
+
+    Fakes logging in by just starting a local process that runs the bootstrap.
+    """
+
+    def __init__(self):
+        super().__init__()
+        configure(default_transport=ChannelTransport.Tcp)
+
+        self._next_port = 12345
+
+    def _start_host(self, host: str) -> ProcessState:
+        port = self._next_port
+        self._next_port += 1
+
+        env = {**os.environ}
+        if "FB_XAR_INVOKED_NAME" in os.environ:
+            env["PYTHONPATH"] = ":".join(sys.path)
+        addr = f"tcp://[::1]:{port}"
+        bind_addr = f"tcp://[::1]:{port}"
+        proc = subprocess.Popen(
+            [
+                sys.executable,
+                "-c",
+                f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(bind_addr)}, ca="trust_all_connections")',
+            ],
+            env=env,
+            start_new_session=True,
+        )
+        return ProcessState(proc.pid, addr)
+
+
+class SSHJob(LoginJob):
+    def __init__(
+        self,
+        python_exe: str = "python",
+        ssh_args: Sequence[str] = (),
+        monarch_port: int = 22222,
+    ):
+        configure(default_transport=ChannelTransport.Tcp)
+        self._python_exe = python_exe
+        self._ssh_args = ssh_args
+        self._port = monarch_port
+        super().__init__()
+
+    def _start_host(self, host: str) -> ProcessState:
+        addr = f"tcp://{host}:{self._port}"
+        startup = f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")'
+
+        command = f"{shlex.quote(self._python_exe)} -c {shlex.quote(startup)}"
+        proc = subprocess.Popen(
+            ["ssh", *self._ssh_args, host, "-n", command],
+            start_new_session=True,
+        )
+        return ProcessState(proc.pid, addr)
+
+    def can_run(self, spec):
+        return (
+            isinstance(spec, SSHJob)
+            and spec._python_exe == self._python_exe
+            and self._port == spec._port
+            and self._ssh_args == spec._ssh_args
+            and super().can_run(spec)
+        )
diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py
index ab2de38c5..4ba507d8c 100644
--- a/python/tests/test_python_actors.py
+++ b/python/tests/test_python_actors.py
@@ -62,6 +62,7 @@
     this_proc as this_proc_v1,
 )
 from monarch._src.actor.v1.proc_mesh import ProcMesh as ProcMeshV1
+from monarch._src.job.job import JobState, LoginJob, ProcessState

 from monarch.actor import (
     Accumulator,
@@ -1721,3 +1722,44 @@ def test_simple_bootstrap():
     for proc in procs:
         proc.kill()
         proc.wait()
+
+
+class FakeLocalLoginJob(LoginJob):
+    """
+
+    Fakes logging in by just starting a local process that runs the bootstrap.
+    """
+
+    def __init__(self, dir: str):
+        super().__init__()
+        self._dir = dir
+
+    def _start_host(self, host: str) -> ProcessState:
+        env = {**os.environ}
+        if "FB_XAR_INVOKED_NAME" in os.environ:
+            env["PYTHONPATH"] = ":".join(sys.path)
+        addr = f"ipc://{self._dir}/{host}"
+        proc = subprocess.Popen(
+            [
+                sys.executable,
+                "-c",
+                f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")',
+            ],
+            env=env,
+            start_new_session=True,
+        )
+        return ProcessState(proc.pid, addr)
+
+
+def test_login_job():
+    with TemporaryDirectory() as temp_dir:
+        j = FakeLocalLoginJob(temp_dir)
+        j.add_mesh("hosts", ["fake", "hosts"])
+        state = j.state(cached_path=None)
+
+        hello = state.hosts.spawn_procs().spawn("hello", Hello)
+        r = hello.doit.call().get()
+        for _, v in r.items():
+            assert v == "hello!"
+
+        j.kill()

From 7d466211dc9b1e5cad32c72a38788b0d7e75ce49 Mon Sep 17 00:00:00 2001
From: zdevito
Date: Tue, 7 Oct 2025 12:08:54 -0700
Subject: [PATCH 2/2] Update on "SSHJob/LoginJob"

Add a simple SSHJob variant that lets you establish a host mesh by
ssh-ing directly into machines.
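
Each host only has to run monarch's worker bootstrap loop; the SSH command that
SSHJob issues per host boils down to the following (a sketch of what the
`python -c` payload in `SSHJob._start_host` executes; the host name and port
below are placeholders):

```python
# Worker-side loop started over SSH on each host (illustrative host/port).
from monarch.actor import run_worker_loop_forever

run_worker_loop_forever(
    address="tcp://host1.example.com:22222",  # tcp://{host}:{monarch_port}
    ca="trust_all_connections",               # TLS plumbing is not implemented yet
)
```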
This is probably too simple to use in practice, but it demonstrates what is
necessary to get a monarch job running.

Differential Revision: [D84016804](https://our.internmc.facebook.com/intern/diff/D84016804/)

[ghstack-poisoned]
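
For reference, a minimal client-side sketch of how the new job types are meant
to be used, modeled on `test_login_job` in this patch; the host names, the
`python3` executable, and the `Hello` actor are illustrative placeholders, not
part of the patch:

```python
# Usage sketch (assumed host names and actor; mirrors the pattern in the test).
from monarch._src.job.job import SSHJob
from monarch.actor import Actor, endpoint


class Hello(Actor):
    # Stand-in for the Hello actor used by the test.
    @endpoint
    def doit(self) -> str:
        return "hello!"


job = SSHJob(python_exe="python3", monarch_port=22222)
job.add_mesh("hosts", ["host1.example.com", "host2.example.com"])

# cached_path=None forces a fresh connection instead of reusing a cached job.
state = job.state(cached_path=None)

hello = state.hosts.spawn_procs().spawn("hello", Hello)
print(hello.doit.call().get())

job.kill()
```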