
Commit b4ba693

[monarch] Revert D83360166 and D83489725
Pull Request resolved: #1408

Disable v1 log forwarding with LoggingManager and re-enable @ShayneFletcher's auto-spawn mechanism, in an attempt to get GitHub CI stable again.

ghstack-source-id: 314042115
@exported-using-ghexport

Differential Revision: [D83777450](https://our.internmc.facebook.com/intern/diff/D83777450/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D83777450/)!
1 parent e5ca0a9 commit b4ba693
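
For reviewers skimming the test changes below: the logging tests are de-parameterized and pinned to the v0 path. A minimal sketch of the flow they now exercise, using only calls that appear in this diff (`this_host().spawn_procs`, the test file's `Printer` actor, an explicit `pm.stop()`); the import locations are assumptions, not verified against the tree:

```python
# Sketch of the v0-only flow the updated tests follow. Names are taken from
# the diff below; the import paths here are assumptions.
from monarch._src.actor.host_mesh import this_host

from tests.test_python_actors import Printer  # the test file's actor class


async def v0_logging_flow() -> None:
    pm = this_host().spawn_procs(per_host={"gpus": 2})
    am = pm.spawn("printer", Printer)

    for _ in range(5):
        await am.print.call("print streaming")

    # On the v0 path, stopping the mesh flushes aggregated logs, so the tests
    # no longer sleep out the aggregation window the way the v1 branches did.
    await pm.stop()
```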

3 files changed (+39, -96 lines)

python/monarch/_src/actor/logging.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -50,8 +50,7 @@ def flush_all_proc_mesh_logs(v1: bool = False) -> None:
     from monarch._src.actor.v1.proc_mesh import get_active_proc_meshes
 
     for pm in get_active_proc_meshes():
-        if pm._logging_manager._logging_mesh_client is not None:
-            pm._logging_manager.flush()
+        pm._logging_manager.flush()
 
 
 class LoggingManager:
```
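
Dropping the `_logging_mesh_client` guard at this call site presumes `LoggingManager.flush()` tolerates a mesh that never initialized log forwarding (see the proc_mesh.py change below). A minimal sketch of that invariant; the real `LoggingManager` is defined in this same file and may differ:

```python
# Sketch only: the invariant the simplified call site relies on.
class LoggingManager:
    def __init__(self) -> None:
        # Populated only once log forwarding is initialized for a mesh.
        self._logging_mesh_client = None

    def flush(self) -> None:
        # Must be safe to call unconditionally, since
        # flush_all_proc_mesh_logs() no longer checks for a client.
        if self._logging_mesh_client is None:
            return
        # ...flush aggregated log lines to the client here...
```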

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -200,7 +200,8 @@ async def task(
     ) -> HyProcMesh:
         hy_proc_mesh = await hy_proc_mesh_task
 
-        await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
+        # FIXME: Fix log forwarding.
+        # await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
 
         if setup_actor is not None:
             await setup_actor.setup.call()
```
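
Note the interaction with the logging.py change above: with `init` commented out, `_logging_mesh_client` presumably never gets set on v1 meshes, so the unconditional `flush()` call degrades to a no-op for them. Illustrative fragment, not code from this commit:

```python
# v1 path after this commit (sketch): log forwarding is never initialized.
pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2})
# Assumption: the client stays unset because init() above is skipped, and
# flush() is expected to tolerate that state (see the logging.py sketch).
pm._logging_manager.flush()
```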

python/tests/test_python_actors.py

Lines changed: 36 additions & 93 deletions

```diff
@@ -30,31 +30,24 @@
     PythonMessage,
     PythonMessageKind,
 )
-from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc, AllocSpec
 from monarch._rust_bindings.monarch_hyperactor.mailbox import (
     PortId,
     PortRef,
     UndeliverableMessageEnvelope,
 )
 from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
-from monarch._rust_bindings.monarch_hyperactor.shape import Extent
 
 from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
-from monarch._src.actor.allocator import AllocHandle, ProcessAllocator
+from monarch._src.actor.allocator import AllocHandle
 from monarch._src.actor.future import Future
 from monarch._src.actor.host_mesh import (
     create_local_host_mesh,
     fake_in_process_host,
     HostMesh,
 )
-from monarch._src.actor.proc_mesh import (
-    _get_bootstrap_args,
-    get_or_spawn_controller,
-    ProcMesh,
-)
+from monarch._src.actor.proc_mesh import get_or_spawn_controller, ProcMesh
 from monarch._src.actor.v1.host_mesh import (
-    _bootstrap_cmd,
     fake_in_process_host as fake_in_process_host_v1,
     HostMesh as HostMeshV1,
     this_host as this_host_v1,

@@ -473,7 +466,7 @@ async def no_more(self) -> None:
 
 
 @pytest.mark.parametrize("v1", [True, False])
-@pytest.mark.timeout(60)
+@pytest.mark.timeout(30)
 async def test_async_concurrency(v1: bool):
     """Test that async endpoints will be processed concurrently."""
     pm = spawn_procs_on_this_host(v1, {})

@@ -610,9 +603,8 @@ def _handle_undeliverable_message(
         return True
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_actor_log_streaming(v1: bool) -> None:
+async def test_actor_log_streaming() -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr

@@ -639,7 +631,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
     sys.stderr = stderr_file
 
     try:
-        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+        pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)
 
         # Disable streaming logs to client

@@ -679,10 +671,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
         await am.print.call("has print streaming too")
         await am.log.call("has log streaming as level matched")
 
-        if not v1:
-            await pm.stop()
-        else:
-            await asyncio.sleep(1)
+        await pm.stop()
 
         # Flush all outputs
         stdout_file.flush()

@@ -763,9 +752,8 @@ async def test_actor_log_streaming(v1: bool) -> None:
         pass
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_alloc_based_log_streaming(v1: bool) -> None:
+async def test_alloc_based_log_streaming() -> None:
     """Test both AllocHandle.stream_logs = False and True cases."""
 
     async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:

@@ -782,50 +770,23 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
 
         try:
             # Create proc mesh with custom stream_logs setting
-            if not v1:
-                host_mesh = create_local_host_mesh()
-                alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
-
-                # Override the stream_logs setting
-                custom_alloc_handle = AllocHandle(
-                    alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
-                )
+            host_mesh = create_local_host_mesh()
+            alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
 
-                pm = ProcMesh.from_alloc(custom_alloc_handle)
-            else:
-
-                class ProcessAllocatorStreamLogs(ProcessAllocator):
-                    def allocate_nonblocking(
-                        self, spec: AllocSpec
-                    ) -> PythonTask[Alloc]:
-                        return super().allocate_nonblocking(spec)
-
-                    def _stream_logs(self) -> bool:
-                        return stream_logs
-
-                alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args())
-
-                host_mesh = HostMeshV1.allocate_nonblocking(
-                    "host",
-                    Extent(["hosts"], [1]),
-                    alloc,
-                    bootstrap_cmd=_bootstrap_cmd(),
-                )
-
-                pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2})
+            # Override the stream_logs setting
+            custom_alloc_handle = AllocHandle(
+                alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+            )
 
+            pm = ProcMesh.from_alloc(custom_alloc_handle)
             am = pm.spawn("printer", Printer)
 
             await pm.initialized
 
             for _ in range(5):
                 await am.print.call(f"{test_name} print streaming")
 
-            if not v1:
-                await pm.stop()
-            else:
-                # Wait for at least the aggregation window (3 seconds)
-                await asyncio.sleep(5)
+            await pm.stop()
 
             # Flush all outputs
             stdout_file.flush()
```
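
Since the surviving `+` lines in the hunk above are interleaved with a long deletion, here is the same v0 body of `test_stream_logs_case` read straight through (condensed from the `+` and context lines only; no new calls):

```python
# Condensed view of the surviving (v0) body above, minus the diff noise.
host_mesh = create_local_host_mesh()
alloc_handle = host_mesh._alloc(hosts=1, gpus=2)

# Rebuild the handle to override stream_logs; per the '+' lines, the third
# positional argument of AllocHandle is the stream_logs flag.
custom_alloc_handle = AllocHandle(
    alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
)

pm = ProcMesh.from_alloc(custom_alloc_handle)
am = pm.spawn("printer", Printer)
await pm.initialized
```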
```diff
@@ -849,18 +810,18 @@ def _stream_logs(self) -> bool:
                 # When stream_logs=False, logs should not be streamed to client
                 assert not re.search(
                     rf"similar log lines.*{test_name} print streaming", stdout_content
-                ), f"stream_logs=False case: {stdout_content}"
+                ), f"stream_logs=True case: {stdout_content}"
                 assert re.search(
                     rf"{test_name} print streaming", stdout_content
-                ), f"stream_logs=False case: {stdout_content}"
+                ), f"stream_logs=True case: {stdout_content}"
             else:
                 # When stream_logs=True, logs should be streamed to client (no aggregation by default)
                 assert re.search(
                     rf"similar log lines.*{test_name} print streaming", stdout_content
-                ), f"stream_logs=True case: {stdout_content}"
+                ), f"stream_logs=False case: {stdout_content}"
                 assert not re.search(
                     rf"\[[0-9]\]{test_name} print streaming", stdout_content
-                ), f"stream_logs=True case: {stdout_content}"
+                ), f"stream_logs=False case: {stdout_content}"
 
         finally:
             # Ensure file descriptors are restored even if something goes wrong

@@ -875,9 +836,8 @@ def _stream_logs(self) -> bool:
     await test_stream_logs_case(True, "stream_logs_true")
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_logging_option_defaults(v1: bool) -> None:
+async def test_logging_option_defaults() -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr

@@ -904,18 +864,14 @@ async def test_logging_option_defaults(v1: bool) -> None:
     sys.stderr = stderr_file
 
     try:
-        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+        pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)
 
         for _ in range(5):
             await am.print.call("print streaming")
             await am.log.call("log streaming")
 
-        if not v1:
-            await pm.stop()
-        else:
-            # Wait for > default aggregation window (3 seconds)
-            await asyncio.sleep(5)
+        await pm.stop()
 
         # Flush all outputs
         stdout_file.flush()

@@ -993,8 +949,7 @@ def __init__(self):
 
 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
-@pytest.mark.parametrize("v1", [True, False])
-async def test_flush_called_only_once(v1: bool) -> None:
+async def test_flush_called_only_once() -> None:
     """Test that flush is called only once when ending an ipython cell"""
     mock_ipython = MockIPython()
     with unittest.mock.patch(

@@ -1006,8 +961,8 @@ async def test_flush_called_only_once(v1: bool) -> None:
         "monarch._src.actor.logging.flush_all_proc_mesh_logs"
     ) as mock_flush:
         # Create 2 proc meshes with a large aggregation window
-        pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
-        _ = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+        pm1 = this_host().spawn_procs(per_host={"gpus": 2})
+        _ = this_host().spawn_procs(per_host={"gpus": 2})
         # flush not yet called unless post_run_cell
         assert mock_flush.call_count == 0
         assert mock_ipython.events.registers == 0

@@ -1021,9 +976,8 @@ async def test_flush_called_only_once(v1: bool) -> None:
 
 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(180)
-async def test_flush_logs_ipython(v1: bool) -> None:
+async def test_flush_logs_ipython() -> None:
     """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
     # Save original file descriptors
     original_stdout_fd = os.dup(1)

@@ -1049,8 +1003,8 @@ async def test_flush_logs_ipython(v1: bool) -> None:
     ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
         # Make sure we can register and unregister callbacks
         for _ in range(3):
-            pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
-            pm2 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm1 = this_host().spawn_procs(per_host={"gpus": 2})
+            pm2 = this_host().spawn_procs(per_host={"gpus": 2})
             am1 = pm1.spawn("printer", Printer)
             am2 = pm2.spawn("printer", Printer)
 

@@ -1154,9 +1108,8 @@ async def test_flush_logs_fast_exit() -> None:
     ), process.stdout
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_flush_on_disable_aggregation(v1: bool) -> None:
+async def test_flush_on_disable_aggregation() -> None:
     """Test that logs are flushed when disabling aggregation.
 
     This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."

@@ -1177,7 +1130,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
     sys.stdout = stdout_file
 
     try:
-        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)
 
         # Set a long aggregation window to ensure logs aren't flushed immediately

@@ -1198,11 +1151,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
         for _ in range(5):
             await am.print.call("single log line")
 
-        if not v1:
-            await pm.stop()
-        else:
-            # Wait for > default aggregation window (3 secs)
-            await asyncio.sleep(5)
+        await pm.stop()
 
         # Flush all outputs
         stdout_file.flush()

@@ -1248,15 +1197,14 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
         pass
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
+async def test_multiple_ongoing_flushes_no_deadlock() -> None:
     """
     The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
    Because now a flush call is purely sync, it is very easy to get into a deadlock.
     So we assert the last flush call will not get into such a state.
     """
-    pm = spawn_procs_on_this_host(v1, per_host={"gpus": 4})
+    pm = this_host().spawn_procs(per_host={"gpus": 4})
     am = pm.spawn("printer", Printer)
 
     # Generate some logs that will be aggregated but not flushed immediately

@@ -1279,9 +1227,8 @@ async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
     futures[-1].get()
 
 
-@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_adjust_aggregation_window(v1: bool) -> None:
+async def test_adjust_aggregation_window() -> None:
     """Test that the flush deadline is updated when the aggregation window is adjusted.
 
     This tests the corner case: "This can happen if the user has adjusted the aggregation window."

@@ -1302,7 +1249,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
     sys.stdout = stdout_file
 
     try:
-        pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+        pm = this_host().spawn_procs(per_host={"gpus": 2})
         am = pm.spawn("printer", Printer)
 
         # Set a long aggregation window initially

@@ -1320,11 +1267,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
         for _ in range(3):
             await am.print.call("second batch of logs")
 
-        if not v1:
-            await pm.stop()
-        else:
-            # Wait for > aggregation window (2 secs)
-            await asyncio.sleep(4)
+        await pm.stop()
 
         # Flush all outputs
         stdout_file.flush()
```
