diff --git a/python/monarch/_src/actor/logging.py b/python/monarch/_src/actor/logging.py
index fa641d2f1..37a06c540 100644
--- a/python/monarch/_src/actor/logging.py
+++ b/python/monarch/_src/actor/logging.py
@@ -50,8 +50,7 @@ def flush_all_proc_mesh_logs(v1: bool = False) -> None:
         from monarch._src.actor.v1.proc_mesh import get_active_proc_meshes
 
     for pm in get_active_proc_meshes():
-        if pm._logging_manager._logging_mesh_client is not None:
-            pm._logging_manager.flush()
+        pm._logging_manager.flush()
 
 
 class LoggingManager:
diff --git a/python/monarch/_src/actor/v1/proc_mesh.py b/python/monarch/_src/actor/v1/proc_mesh.py
index 5ce96c970..21f00852d 100644
--- a/python/monarch/_src/actor/v1/proc_mesh.py
+++ b/python/monarch/_src/actor/v1/proc_mesh.py
@@ -200,7 +200,8 @@ async def task(
     ) -> HyProcMesh:
         hy_proc_mesh = await hy_proc_mesh_task
 
-        await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
+        # FIXME: Fix log forwarding.
+        # await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
 
         if setup_actor is not None:
             await setup_actor.setup.call()
diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py
index 6f4a4d42f..705df5966 100644
--- a/python/tests/test_python_actors.py
+++ b/python/tests/test_python_actors.py
@@ -30,7 +30,6 @@
     PythonMessage,
     PythonMessageKind,
 )
-from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc, AllocSpec
 from monarch._rust_bindings.monarch_hyperactor.mailbox import (
     PortId,
     PortRef,
@@ -38,23 +37,17 @@
 )
 from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
-from monarch._rust_bindings.monarch_hyperactor.shape import Extent
 from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
-from monarch._src.actor.allocator import AllocHandle, ProcessAllocator
+from monarch._src.actor.allocator import AllocHandle
 from monarch._src.actor.future import Future
 from monarch._src.actor.host_mesh import (
     create_local_host_mesh,
     fake_in_process_host,
     HostMesh,
 )
-from monarch._src.actor.proc_mesh import (
-    _get_bootstrap_args,
-    get_or_spawn_controller,
-    ProcMesh,
-)
+from monarch._src.actor.proc_mesh import get_or_spawn_controller, ProcMesh
 from monarch._src.actor.v1.host_mesh import (
-    _bootstrap_cmd,
     fake_in_process_host as fake_in_process_host_v1,
     HostMesh as HostMeshV1,
     this_host as this_host_v1,
 )
@@ -473,7 +466,7 @@ async def no_more(self) -> None:
 
 
 @pytest.mark.parametrize("v1", [True, False])
-@pytest.mark.timeout(60)
+@pytest.mark.timeout(30)
 async def test_async_concurrency(v1: bool):
     """Test that async endpoints will be processed concurrently."""
     pm = spawn_procs_on_this_host(v1, {})
@@ -610,9 +603,8 @@ def _handle_undeliverable_message(
         return True
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_actor_log_streaming(v1: bool) -> None:
+async def test_actor_log_streaming() -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr
@@ -639,7 +631,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
         sys.stderr = stderr_file
 
         try:
-            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)
 
             # Disable streaming logs to client
@@ -679,10 +671,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
             await am.print.call("has print streaming too")
             await am.log.call("has log streaming as level matched")
 
-            if not v1:
-                await pm.stop()
-            else:
-                await asyncio.sleep(1)
+            await pm.stop()
 
             # Flush all outputs
             stdout_file.flush()
@@ -763,9 +752,8 @@ async def test_actor_log_streaming(v1: bool) -> None:
             pass
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_alloc_based_log_streaming(v1: bool) -> None:
+async def test_alloc_based_log_streaming() -> None:
     """Test both AllocHandle.stream_logs = False and True cases."""
 
     async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
@@ -782,38 +770,15 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
 
         try:
             # Create proc mesh with custom stream_logs setting
-            if not v1:
-                host_mesh = create_local_host_mesh()
-                alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
-
-                # Override the stream_logs setting
-                custom_alloc_handle = AllocHandle(
-                    alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
-                )
+            host_mesh = create_local_host_mesh()
+            alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
 
-                pm = ProcMesh.from_alloc(custom_alloc_handle)
-            else:
-
-                class ProcessAllocatorStreamLogs(ProcessAllocator):
-                    def allocate_nonblocking(
-                        self, spec: AllocSpec
-                    ) -> PythonTask[Alloc]:
-                        return super().allocate_nonblocking(spec)
-
-                    def _stream_logs(self) -> bool:
-                        return stream_logs
-
-                alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args())
-
-                host_mesh = HostMeshV1.allocate_nonblocking(
-                    "host",
-                    Extent(["hosts"], [1]),
-                    alloc,
-                    bootstrap_cmd=_bootstrap_cmd(),
-                )
-
-                pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2})
-
+            # Override the stream_logs setting
+            custom_alloc_handle = AllocHandle(
+                alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+            )
+
+            pm = ProcMesh.from_alloc(custom_alloc_handle)
 
             am = pm.spawn("printer", Printer)
             await pm.initialized
@@ -821,11 +786,7 @@ def _stream_logs(self) -> bool:
             for _ in range(5):
                 await am.print.call(f"{test_name} print streaming")
 
-            if not v1:
-                await pm.stop()
-            else:
-                # Wait for at least the aggregation window (3 seconds)
-                await asyncio.sleep(5)
+            await pm.stop()
 
             # Flush all outputs
             stdout_file.flush()
@@ -875,9 +836,8 @@ def _stream_logs(self) -> bool:
     await test_stream_logs_case(True, "stream_logs_true")
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_logging_option_defaults(v1: bool) -> None:
+async def test_logging_option_defaults() -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr
@@ -904,18 +864,14 @@ async def test_logging_option_defaults(v1: bool) -> None:
         sys.stderr = stderr_file
 
         try:
-            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)
 
             for _ in range(5):
                 await am.print.call("print streaming")
                 await am.log.call("log streaming")
 
-            if not v1:
-                await pm.stop()
-            else:
-                # Wait for > default aggregation window (3 seconds)
-                await asyncio.sleep(5)
+            await pm.stop()
 
             # Flush all outputs
             stdout_file.flush()
@@ -993,8 +949,7 @@ def __init__(self):
 
 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
-@pytest.mark.parametrize("v1", [True, False])
-async def test_flush_called_only_once(v1: bool) -> None:
+async def test_flush_called_only_once() -> None:
     """Test that flush is called only once when ending an ipython cell"""
     mock_ipython = MockIPython()
     with unittest.mock.patch(
@@ -1006,8 +961,8 @@ async def test_flush_called_only_once(v1: bool) -> None:
         "monarch._src.actor.logging.flush_all_proc_mesh_logs"
     ) as mock_flush:
         # Create 2 proc meshes with a large aggregation window
-        pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
-        _ = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+        pm1 = this_host().spawn_procs(per_host={"gpus": 2})
+        _ = this_host().spawn_procs(per_host={"gpus": 2})
         # flush not yet called unless post_run_cell
         assert mock_flush.call_count == 0
         assert mock_ipython.events.registers == 0
@@ -1021,9 +976,8 @@ async def test_flush_called_only_once(v1: bool) -> None:
 
 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(180)
-async def test_flush_logs_ipython(v1: bool) -> None:
+async def test_flush_logs_ipython() -> None:
     """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
@@ -1049,8 +1003,8 @@ async def test_flush_logs_ipython(v1: bool) -> None:
     ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
         # Make sure we can register and unregister callbacks
         for _ in range(3):
-            pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
-            pm2 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm1 = this_host().spawn_procs(per_host={"gpus": 2})
+            pm2 = this_host().spawn_procs(per_host={"gpus": 2})
             am1 = pm1.spawn("printer", Printer)
             am2 = pm2.spawn("printer", Printer)
 
@@ -1154,9 +1108,8 @@ async def test_flush_logs_fast_exit() -> None:
     ), process.stdout
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_flush_on_disable_aggregation(v1: bool) -> None:
+async def test_flush_on_disable_aggregation() -> None:
     """Test that logs are flushed when disabling aggregation.
 
     This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
@@ -1177,7 +1130,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
         sys.stdout = stdout_file
 
         try:
-            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm = this_host().spawn_procs(per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)
 
             # Set a long aggregation window to ensure logs aren't flushed immediately
@@ -1198,11 +1151,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
             for _ in range(5):
                 await am.print.call("single log line")
 
-            if not v1:
-                await pm.stop()
-            else:
-                # Wait for > default aggregation window (3 secs)
-                await asyncio.sleep(5)
+            await pm.stop()
 
             # Flush all outputs
             stdout_file.flush()
@@ -1248,15 +1197,14 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
         pass
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
+async def test_multiple_ongoing_flushes_no_deadlock() -> None:
     """
     The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
     Because now a flush call is purely sync, it is very easy to get into a deadlock.
     So we assert the last flush call will not get into such a state.
     """
-    pm = spawn_procs_on_this_host(v1, per_host={"gpus": 4})
+    pm = this_host().spawn_procs(per_host={"gpus": 4})
     am = pm.spawn("printer", Printer)
 
     # Generate some logs that will be aggregated but not flushed immediately
@@ -1279,9 +1227,8 @@ async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
     futures[-1].get()
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_adjust_aggregation_window(v1: bool) -> None:
+async def test_adjust_aggregation_window() -> None:
     """Test that the flush deadline is updated when the aggregation window is adjusted.
 
     This tests the corner case: "This can happen if the user has adjusted the aggregation window."
@@ -1302,7 +1249,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
         sys.stdout = stdout_file
 
         try:
-            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm = this_host().spawn_procs(per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)
 
             # Set a long aggregation window initially
@@ -1320,11 +1267,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
             for _ in range(3):
                 await am.print.call("second batch of logs")
 
-            if not v1:
-                await pm.stop()
-            else:
-                # Wait for > aggregation window (2 secs)
-                await asyncio.sleep(4)
+            await pm.stop()
 
             # Flush all outputs
             stdout_file.flush()
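The updated tests converge on one v0 pattern: allocate on a local host mesh, override `stream_logs` by rebuilding the `AllocHandle`, and finish with `ProcMesh.stop()` instead of sleeping past the log-aggregation window. Below is a minimal standalone sketch of that pattern, assembled only from calls visible in this diff; the mesh sizes and the stop-implies-flush behavior are assumptions inferred from the test changes, not documented API guarantees.

```python
# Sketch of the stream_logs override used by test_alloc_based_log_streaming
# after this change. Assumption: pm.stop() performs the final log flush,
# which is why the asyncio.sleep() waits could be deleted from the tests.
import asyncio

from monarch._src.actor.allocator import AllocHandle
from monarch._src.actor.host_mesh import create_local_host_mesh
from monarch._src.actor.proc_mesh import ProcMesh


async def main() -> None:
    host_mesh = create_local_host_mesh()
    alloc_handle = host_mesh._alloc(hosts=1, gpus=2)

    # Rebuild the alloc handle with stream_logs=False so child process
    # logs are not forwarded to the client.
    quiet_handle = AllocHandle(
        alloc_handle._hy_alloc, alloc_handle._extent, False
    )
    pm = ProcMesh.from_alloc(quiet_handle)
    await pm.initialized

    # ... spawn actors and call endpoints here ...

    await pm.stop()  # also flushes any buffered, aggregated log lines


asyncio.run(main())
```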