6 changes: 6 additions & 0 deletions python/monarch/_src/actor/bootstrap.py
@@ -73,6 +73,11 @@ def run_worker_loop_forever(
raise NotImplementedError("TLS security plumbing")
# we maybe want to actually return the future and let you do other stuff,
# not sure ...
if "tcp://*" in address:
raise NotImplementedError(
"implementation does not get the host name right if it was specified as a wild card. We have to fix this"
)

_run_worker_loop_forever(address).block_on()


@@ -104,6 +109,7 @@ def attach_to_workers(

if private_key is not None or ca != "trust_all_connections":
raise NotImplementedError("TLS security plumbing")

workers_tasks = [_as_python_task(w) for w in workers]
host_mesh: PythonTask[HyHostMesh] = _attach_to_workers(workers_tasks, name=name)
extent = Extent(["hosts"], [len(workers)])
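The hunk above rejects wildcard bind addresses in run_worker_loop_forever, while attach_to_workers keeps its existing TLS guard. Below is a minimal sketch of how the two entry points pair up (not part of the diff, and mirroring what FakeLocalLoginJob does later in this change): start a worker loop in a subprocess on an explicit address, then attach to it from the client. The port is an arbitrary example; attach_to_workers is imported from monarch._src.actor.bootstrap exactly as job.py does below.

import subprocess
import sys

from monarch._src.actor.bootstrap import attach_to_workers

# Explicit address; "tcp://*" would hit the new NotImplementedError above.
addr = "tcp://[::1]:22222"

# Worker process: serve the bootstrap loop forever on that address.
worker = subprocess.Popen(
    [
        sys.executable,
        "-c",
        f'from monarch.actor import run_worker_loop_forever; '
        f'run_worker_loop_forever(address={addr!r}, ca="trust_all_connections")',
    ],
    start_new_session=True,
)

# Client side: attach to the running worker and get a host mesh back.
host_mesh = attach_to_workers(name="hosts", ca="trust_all_connections", workers=[addr])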
201 changes: 191 additions & 10 deletions python/monarch/_src/job/job.py
@@ -4,17 +4,26 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import pickle
import shlex
import signal
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, Literal, NamedTuple, Optional, Sequence
from typing import cast, Dict, List, Literal, NamedTuple, Optional, Sequence

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure

from monarch._src.actor.bootstrap import attach_to_workers

# note: the jobs api is intended as a library so it should
# only be importing _public_ monarch API functions.
from monarch._src.actor.host_mesh import HostMesh, this_host

from typing_extensions import Self


@@ -39,6 +48,12 @@ class CachedRunning(NamedTuple):
job: "JobTrait"


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stderr))
logger.propagate = False


class JobTrait(ABC):
def __init__(self):
super().__init__()
@@ -102,6 +117,10 @@ def apply(self, client_script: Optional[str] = None):
self._create(client_script)
self._status = "running"

@property
def active(self) -> bool:
return self._running is not None

def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobState:
"""
        Get the current state of this job, containing the host mesh objects for the requirements that were requested
@@ -124,30 +143,44 @@ def state(self, cached_path: Optional[str] = ".monarch/job_state.pkl") -> JobState:
# calls to attach_to_workers and return the HostMeshes
running_job = self._running
if running_job is not None:
logger.info("Job is running, returning current state")
return running_job._state()

cached = self._load_cached(cached_path)
if cached is not None:
self._status = CachedRunning(cached)
logger.info("Connecting to cached job")
return cached._state()
logger.info("Applying current job")
self.apply()
if cached_path is not None:
# Create the directory for cached_path if it doesn't exist
cache_dir = os.path.dirname(cached_path)
if cache_dir: # Only create if there's a directory component
os.makedirs(cache_dir, exist_ok=True)
logger.info("Saving job to cache at %s", cached_path)
self.dump(cached_path)
logger.info("Job has started, connecting to current state")
return self._state()

def _load_cached(self, cached_path: Optional[str]) -> "Optional[JobTrait]":
if cached_path is None:
logger.info("No cached path provided")
return None
try:
job = job_load(cached_path)
logger.info("Found cached job at path: %s", cached_path)
except FileNotFoundError:
logger.info("No cached job found at path: %s", cached_path)
return None
running = job._running
if running is None or not running.can_run(self):
if running is None:
logger.info("Cached job is not running")
return None
if not running.can_run(self):
logger.info("Cached job cannot run this spec, removing cache")
running._kill()
os.remove(cached_path)
return None
return job

@@ -164,6 +197,12 @@ def dumps(self) -> bytes:
# @lint-ignore PYTHONPICKLEISBAD
return pickle.dumps(self)

def kill(self):
running = self._running
if running is not None:
running._kill()
self._status = "not_running"

@abstractmethod
def _state(self) -> JobState: ...

@@ -181,11 +220,6 @@ def can_run(self, spec: "JobTrait") -> bool:

...

def kill(self):
running = self._running
if running is not None:
running._kill()

@abstractmethod
def _kill(self):
"""
@@ -244,8 +278,10 @@ def _create(self, client_script: Optional[str]):
log_dir = self._setup_log_directory()
self._run_client_as_daemon(client_script, log_dir)

print(f"Started client script {client_script} with PID: {self.process.pid}")
print(f"Logs available at: {log_dir}")
logger.info(
"Started client script %s with PID: %d", client_script, self.process.pid
)
logger.info("Logs available at: %s", log_dir)

def _setup_log_directory(self) -> str:
"""Create a log directory for the batch job."""
@@ -323,5 +359,150 @@ def _create(self, client_script: Optional[str] = None):
return self._job._create(client_script)

def _kill(self):
print("Stopping Batch Job")
logger.info("Stopping Batch Job")
return self._job._kill()


class ProcessState(NamedTuple):
pid: int
channel: str


class LoginJob(JobTrait):
"""
    Makes connections directly to hosts via an explicit list.
"""

def __init__(self):
super().__init__()
self._meshes: Dict[str, List[str]] = {}
self._host_to_pid: Dict[str, ProcessState] = {}

def add_mesh(self, name: str, hosts: List[str]):
self._meshes[name] = hosts

def _state(self) -> JobState:
if not self._pids_active():
raise RuntimeError("lost connection")
hosts = {
name: cast(
"HostMesh",
attach_to_workers(
name=name,
ca="trust_all_connections",
workers=[self._host_to_pid[v].channel for v in values],
),
)
for name, values in self._meshes.items()
}
return JobState(hosts)

def _create(self, client_script: Optional[str]):
if client_script is not None:
raise RuntimeError("LoginJob cannot run batch-mode scripts")

for hosts in self._meshes.values():
for host in hosts:
self._host_to_pid[host] = self._start_host(host)

@abstractmethod
def _start_host(self, host: str) -> ProcessState: ...

def can_run(self, spec: "JobTrait") -> bool:
"""
Is this job capable of running the job spec? This is used to check if a
        cached job can be used to run `spec` instead of creating a new reservation.

It is also used by the batch run infrastructure to indicate that the batch job can certainly run itself.
"""
return (
isinstance(spec, LoginJob)
and spec._meshes == self._meshes
and self._pids_active()
)

def _pids_active(self) -> bool:
if not self.active:
return False
for _, p in self._host_to_pid.items():
try:
# Check if process exists by sending signal 0
os.kill(p.pid, 0)
except OSError:
# Process doesn't exist or we don't have permission to signal it
return False
return True

def _kill(self):
for p in self._host_to_pid.values():
try:
os.kill(p.pid, signal.SIGKILL)
except OSError:
pass


class FakeLocalLoginJob(LoginJob):
"""

Fake it that we are logging in by just making a local process that runs the bootstrap.
"""

def __init__(self):
super().__init__()
configure(default_transport=ChannelTransport.Tcp)

self._next_port = 12345

def _start_host(self, host: str) -> ProcessState:
port = self._next_port
self._next_port += 1

env = {**os.environ}
if "FB_XAR_INVOKED_NAME" in os.environ:
env["PYTHONPATH"] = ":".join(sys.path)
addr = f"tcp://[::1]:{port}"
bind_addr = f"tcp://[::1]:{port}"
proc = subprocess.Popen(
[
sys.executable,
"-c",
f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(bind_addr)}, ca="trust_all_connections")',
],
env=env,
start_new_session=True,
)
return ProcessState(proc.pid, addr)


class SSHJob(LoginJob):
def __init__(
self,
python_exe: str = "python",
ssh_args: Sequence[str] = (),
monarch_port: int = 22222,
):
configure(default_transport=ChannelTransport.Tcp)
self._python_exe = python_exe
self._ssh_args = ssh_args
self._port = monarch_port
super().__init__()

def _start_host(self, host: str) -> ProcessState:
addr = f"tcp://{host}:{self._port}"
startup = f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")'

command = f"{shlex.quote(self._python_exe)} -c {shlex.quote(startup)}"
proc = subprocess.Popen(
["ssh", *self._ssh_args, host, "-n", command],
start_new_session=True,
)
return ProcessState(proc.pid, addr)

def can_run(self, spec):
return (
isinstance(spec, SSHJob)
and spec._python_exe == self._python_exe
and self._port == spec._port
and self._ssh_args == spec._ssh_args
and super().can_run(spec)
)
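Putting the new pieces together: a LoginJob subclass is constructed, one or more named host meshes are registered with add_mesh, and state() either reconnects to a compatible cached job (via can_run) or applies this one and attaches to the freshly started worker loops. The sketch below is illustrative only: the host names, python executable, port, and ssh options are assumptions; it imports SSHJob from monarch._src.job.job, where this diff defines it, and the mesh access and spawn calls mirror test_login_job below.

from monarch._src.job.job import SSHJob

job = SSHJob(python_exe="python3", ssh_args=("-o", "BatchMode=yes"), monarch_port=22222)
job.add_mesh("hosts", ["host1.example.com", "host2.example.com"])

# Reuse a compatible cached job if one exists; otherwise apply, cache, and attach.
state = job.state(cached_path=".monarch/job_state.pkl")

procs = state.hosts.spawn_procs()  # proc mesh spanning the registered hosts
# ... spawn actors on procs and do work ...

job.kill()  # sends SIGKILL to the local ssh processes backing the worker loops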
42 changes: 42 additions & 0 deletions python/tests/test_python_actors.py
@@ -67,6 +67,7 @@
get_or_spawn_controller as get_or_spawn_controller_v1,
ProcMesh as ProcMeshV1,
)
from monarch._src.job.job import LoginJob, ProcessState

from monarch.actor import (
Accumulator,
@@ -1800,3 +1801,44 @@ def test_this_host() -> None:
expected_hosts_by_rank[6],
expected_hosts_by_rank[10],
]


class FakeLocalLoginJob(LoginJob):
"""

Fake it that we are logging in by just making a local process that runs the bootstrap.
"""

def __init__(self, dir: str):
super().__init__()
self._dir = dir

def _start_host(self, host: str) -> ProcessState:
env = {**os.environ}
if "FB_XAR_INVOKED_NAME" in os.environ:
env["PYTHONPATH"] = ":".join(sys.path)
addr = f"ipc://{self._dir}/{host}"
proc = subprocess.Popen(
[
sys.executable,
"-c",
f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(addr)}, ca="trust_all_connections")',
],
env=env,
start_new_session=True,
)
return ProcessState(proc.pid, addr)


def test_login_job():
with TemporaryDirectory() as temp_dir:
j = FakeLocalLoginJob(temp_dir)
j.add_mesh("hosts", ["fake", "hosts"])
state = j.state(cached_path=None)

hello = state.hosts.spawn_procs().spawn("hello", Hello)
r = hello.doit.call().get()
for _, v in r.items():
assert v == "hello!"

j.kill()