3 changes: 1 addition & 2 deletions python/monarch/_src/actor/logging.py
@@ -50,8 +50,7 @@ def flush_all_proc_mesh_logs(v1: bool = False) -> None:
from monarch._src.actor.v1.proc_mesh import get_active_proc_meshes

for pm in get_active_proc_meshes():
if pm._logging_manager._logging_mesh_client is not None:
pm._logging_manager.flush()
pm._logging_manager.flush()


class LoggingManager:
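The hunk above removes the _logging_mesh_client is not None guard (the two deleted lines) and makes the flush unconditional (the one added line). A minimal sketch of the resulting loop, with one assumption the diff itself does not show: LoggingManager.flush() is expected to no-op when no logging mesh client has been initialized.

    for pm in get_active_proc_meshes():
        # Assumption: flush() is safe to call before a logging mesh client
        # is attached; the None-check that used to live here is gone.
        pm._logging_manager.flush()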
3 changes: 2 additions & 1 deletion python/monarch/_src/actor/v1/proc_mesh.py
@@ -200,7 +200,8 @@ async def task(
) -> HyProcMesh:
hy_proc_mesh = await hy_proc_mesh_task

await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
# FIXME: Fix log forwarding.
# await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)

if setup_actor is not None:
await setup_actor.setup.call()
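With the v1 logging init commented out here while flush_all_proc_mesh_logs (previous file) now flushes unconditionally, LoggingManager.flush() has to tolerate an uninitialized client. A hypothetical sketch of that guard; the real LoggingManager body is not part of this diff, and the client-side flush call is an assumption:

    class LoggingManager:
        def __init__(self) -> None:
            # Populated by init(); stays None for v1 proc meshes until the
            # FIXME above (log forwarding) is resolved.
            self._logging_mesh_client = None

        def flush(self) -> None:
            # No-op when logging was never initialized, so unconditional
            # callers such as flush_all_proc_mesh_logs stay safe.
            if self._logging_mesh_client is None:
                return
            self._logging_mesh_client.flush()  # hypothetical client API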
129 changes: 36 additions & 93 deletions python/tests/test_python_actors.py
@@ -30,31 +30,24 @@
PythonMessage,
PythonMessageKind,
)
from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc, AllocSpec
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
PortId,
PortRef,
UndeliverableMessageEnvelope,
)
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
from monarch._rust_bindings.monarch_hyperactor.shape import Extent

from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
from monarch._src.actor.allocator import AllocHandle, ProcessAllocator
from monarch._src.actor.allocator import AllocHandle
from monarch._src.actor.future import Future
from monarch._src.actor.host_mesh import (
create_local_host_mesh,
fake_in_process_host,
HostMesh,
)
from monarch._src.actor.proc_mesh import (
_get_bootstrap_args,
get_or_spawn_controller,
ProcMesh,
)
from monarch._src.actor.proc_mesh import get_or_spawn_controller, ProcMesh
from monarch._src.actor.v1.host_mesh import (
_bootstrap_cmd,
fake_in_process_host as fake_in_process_host_v1,
HostMesh as HostMeshV1,
this_host as this_host_v1,
@@ -473,7 +466,7 @@ async def no_more(self) -> None:


@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(60)
@pytest.mark.timeout(30)
async def test_async_concurrency(v1: bool):
"""Test that async endpoints will be processed concurrently."""
pm = spawn_procs_on_this_host(v1, {})
@@ -610,9 +603,8 @@ def _handle_undeliverable_message(
return True


@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(60)
async def test_actor_log_streaming(v1: bool) -> None:
async def test_actor_log_streaming() -> None:
# Save original file descriptors
original_stdout_fd = os.dup(1) # stdout
original_stderr_fd = os.dup(2) # stderr
@@ -639,7 +631,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
sys.stderr = stderr_file

try:
pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
am = pm.spawn("printer", Printer)

# Disable streaming logs to client
@@ -679,10 +671,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
await am.print.call("has print streaming too")
await am.log.call("has log streaming as level matched")

if not v1:
await pm.stop()
else:
await asyncio.sleep(1)
await pm.stop()

# Flush all outputs
stdout_file.flush()
@@ -763,9 +752,8 @@ async def test_actor_log_streaming(v1: bool) -> None:
pass


@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(120)
async def test_alloc_based_log_streaming(v1: bool) -> None:
async def test_alloc_based_log_streaming() -> None:
"""Test both AllocHandle.stream_logs = False and True cases."""

async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
@@ -782,50 +770,23 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:

try:
# Create proc mesh with custom stream_logs setting
if not v1:
host_mesh = create_local_host_mesh()
alloc_handle = host_mesh._alloc(hosts=1, gpus=2)

# Override the stream_logs setting
custom_alloc_handle = AllocHandle(
alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
)
host_mesh = create_local_host_mesh()
alloc_handle = host_mesh._alloc(hosts=1, gpus=2)

pm = ProcMesh.from_alloc(custom_alloc_handle)
else:

class ProcessAllocatorStreamLogs(ProcessAllocator):
def allocate_nonblocking(
self, spec: AllocSpec
) -> PythonTask[Alloc]:
return super().allocate_nonblocking(spec)

def _stream_logs(self) -> bool:
return stream_logs

alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args())

host_mesh = HostMeshV1.allocate_nonblocking(
"host",
Extent(["hosts"], [1]),
alloc,
bootstrap_cmd=_bootstrap_cmd(),
)

pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2})
# Override the stream_logs setting
custom_alloc_handle = AllocHandle(
alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
)

pm = ProcMesh.from_alloc(custom_alloc_handle)
am = pm.spawn("printer", Printer)

await pm.initialized

for _ in range(5):
await am.print.call(f"{test_name} print streaming")

if not v1:
await pm.stop()
else:
# Wait for at least the aggregation window (3 seconds)
await asyncio.sleep(5)
await pm.stop()

# Flush all outputs
stdout_file.flush()
@@ -849,18 +810,18 @@ def _stream_logs(self) -> bool:
# When stream_logs=False, logs should not be streamed to client
assert not re.search(
rf"similar log lines.*{test_name} print streaming", stdout_content
), f"stream_logs=False case: {stdout_content}"
), f"stream_logs=True case: {stdout_content}"
assert re.search(
rf"{test_name} print streaming", stdout_content
), f"stream_logs=False case: {stdout_content}"
), f"stream_logs=True case: {stdout_content}"
else:
# When stream_logs=True, logs should be streamed to client (no aggregation by default)
assert re.search(
rf"similar log lines.*{test_name} print streaming", stdout_content
), f"stream_logs=True case: {stdout_content}"
), f"stream_logs=False case: {stdout_content}"
assert not re.search(
rf"\[[0-9]\]{test_name} print streaming", stdout_content
), f"stream_logs=True case: {stdout_content}"
), f"stream_logs=False case: {stdout_content}"

finally:
# Ensure file descriptors are restored even if something goes wrong
@@ -875,9 +836,8 @@ def _stream_logs(self) -> bool:
await test_stream_logs_case(True, "stream_logs_true")


@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(60)
async def test_logging_option_defaults(v1: bool) -> None:
async def test_logging_option_defaults() -> None:
# Save original file descriptors
original_stdout_fd = os.dup(1) # stdout
original_stderr_fd = os.dup(2) # stderr
@@ -904,18 +864,14 @@ async def test_logging_option_defaults(v1: bool) -> None:
sys.stderr = stderr_file

try:
pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
am = pm.spawn("printer", Printer)

for _ in range(5):
await am.print.call("print streaming")
await am.log.call("log streaming")

if not v1:
await pm.stop()
else:
# Wait for > default aggregation window (3 seconds)
await asyncio.sleep(5)
await pm.stop()

# Flush all outputs
stdout_file.flush()
@@ -993,8 +949,7 @@ def __init__(self):

# oss_skip: pytest keeps complaining about mocking get_ipython module
@pytest.mark.oss_skip
@pytest.mark.parametrize("v1", [True, False])
async def test_flush_called_only_once(v1: bool) -> None:
async def test_flush_called_only_once() -> None:
"""Test that flush is called only once when ending an ipython cell"""
mock_ipython = MockIPython()
with unittest.mock.patch(
@@ -1006,8 +961,8 @@ async def test_flush_called_only_once(v1: bool) -> None:
"monarch._src.actor.logging.flush_all_proc_mesh_logs"
) as mock_flush:
# Create 2 proc meshes with a large aggregation window
pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
_ = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
pm1 = this_host().spawn_procs(per_host={"gpus": 2})
_ = this_host().spawn_procs(per_host={"gpus": 2})
# flush not yet called unless post_run_cell
assert mock_flush.call_count == 0
assert mock_ipython.events.registers == 0
@@ -1021,9 +976,8 @@

# oss_skip: pytest keeps complaining about mocking get_ipython module
@pytest.mark.oss_skip
@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(180)
async def test_flush_logs_ipython(v1: bool) -> None:
async def test_flush_logs_ipython() -> None:
"""Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
# Save original file descriptors
original_stdout_fd = os.dup(1) # stdout
@@ -1049,8 +1003,8 @@ async def test_flush_logs_ipython(v1: bool) -> None:
), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
# Make sure we can register and unregister callbacks
for _ in range(3):
pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
pm2 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
pm1 = this_host().spawn_procs(per_host={"gpus": 2})
pm2 = this_host().spawn_procs(per_host={"gpus": 2})
am1 = pm1.spawn("printer", Printer)
am2 = pm2.spawn("printer", Printer)

@@ -1154,9 +1108,8 @@ async def test_flush_logs_fast_exit() -> None:
), process.stdout


@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(60)
async def test_flush_on_disable_aggregation(v1: bool) -> None:
async def test_flush_on_disable_aggregation() -> None:
"""Test that logs are flushed when disabling aggregation.

This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
@@ -1177,7 +1130,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
sys.stdout = stdout_file

try:
pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
pm = this_host().spawn_procs(per_host={"gpus": 2})
am = pm.spawn("printer", Printer)

# Set a long aggregation window to ensure logs aren't flushed immediately
@@ -1198,11 +1151,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
for _ in range(5):
await am.print.call("single log line")

if not v1:
await pm.stop()
else:
# Wait for > default aggregation window (3 secs)
await asyncio.sleep(5)
await pm.stop()

# Flush all outputs
stdout_file.flush()
@@ -1248,15 +1197,14 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
pass


@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(120)
async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
async def test_multiple_ongoing_flushes_no_deadlock() -> None:
"""
The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
Because now a flush call is purely sync, it is very easy to get into a deadlock.
So we assert the last flush call will not get into such a state.
"""
pm = spawn_procs_on_this_host(v1, per_host={"gpus": 4})
pm = this_host().spawn_procs(per_host={"gpus": 4})
am = pm.spawn("printer", Printer)

# Generate some logs that will be aggregated but not flushed immediately
@@ -1279,9 +1227,8 @@ async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
futures[-1].get()


@pytest.mark.parametrize("v1", [True, False])
@pytest.mark.timeout(60)
async def test_adjust_aggregation_window(v1: bool) -> None:
async def test_adjust_aggregation_window() -> None:
"""Test that the flush deadline is updated when the aggregation window is adjusted.

This tests the corner case: "This can happen if the user has adjusted the aggregation window."
@@ -1302,7 +1249,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
sys.stdout = stdout_file

try:
pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
pm = this_host().spawn_procs(per_host={"gpus": 2})
am = pm.spawn("printer", Printer)

# Set a long aggregation window initially
@@ -1320,11 +1267,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
for _ in range(3):
await am.print.call("second batch of logs")

if not v1:
await pm.stop()
else:
# Wait for > aggregation window (2 secs)
await asyncio.sleep(4)
await pm.stop()

# Flush all outputs
stdout_file.flush()
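For reference, the stream_logs override that test_alloc_based_log_streaming now uses (in place of the removed ProcessAllocatorStreamLogs subclass) can be written as a standalone helper. This sketch mirrors the test code above; the helper name is made up here, and the positional AllocHandle(_hy_alloc, _extent, stream_logs) call plus the private _alloc/_hy_alloc/_extent attributes are taken from that test code rather than from a documented API:

    from monarch._src.actor.allocator import AllocHandle
    from monarch._src.actor.host_mesh import create_local_host_mesh
    from monarch._src.actor.proc_mesh import ProcMesh

    def spawn_procs_with_stream_logs(stream_logs: bool, gpus: int = 2) -> ProcMesh:
        host_mesh = create_local_host_mesh()
        alloc_handle = host_mesh._alloc(hosts=1, gpus=gpus)
        # Rebuild the handle so that only the stream_logs flag differs.
        custom_handle = AllocHandle(
            alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
        )
        return ProcMesh.from_alloc(custom_handle)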