3030 PythonMessage ,
3131 PythonMessageKind ,
3232)
33- from monarch ._rust_bindings .monarch_hyperactor .alloc import Alloc , AllocSpec
3433from monarch ._rust_bindings .monarch_hyperactor .mailbox import (
3534 PortId ,
3635 PortRef ,
3736 UndeliverableMessageEnvelope ,
3837)
3938from monarch ._rust_bindings .monarch_hyperactor .proc import ActorId
4039from monarch ._rust_bindings .monarch_hyperactor .pytokio import PythonTask
41- from monarch ._rust_bindings .monarch_hyperactor .shape import Extent
4240
4341from monarch ._src .actor .actor_mesh import ActorMesh , Channel , context , Port
44- from monarch ._src .actor .allocator import AllocHandle , ProcessAllocator
42+ from monarch ._src .actor .allocator import AllocHandle
4543from monarch ._src .actor .future import Future
4644from monarch ._src .actor .host_mesh import (
4745 create_local_host_mesh ,
4846 fake_in_process_host ,
4947 HostMesh ,
5048)
51- from monarch ._src .actor .proc_mesh import (
52- _get_bootstrap_args ,
53- get_or_spawn_controller ,
54- ProcMesh ,
55- )
49+ from monarch ._src .actor .proc_mesh import get_or_spawn_controller , ProcMesh
5650from monarch ._src .actor .v1 .host_mesh import (
57- _bootstrap_cmd ,
5851 fake_in_process_host as fake_in_process_host_v1 ,
5952 HostMesh as HostMeshV1 ,
6053 this_host as this_host_v1 ,
@@ -473,7 +466,7 @@ async def no_more(self) -> None:
473466
474467
475468@pytest .mark .parametrize ("v1" , [True , False ])
476- @pytest .mark .timeout (60 )
469+ @pytest .mark .timeout (30 )
477470async def test_async_concurrency (v1 : bool ):
478471 """Test that async endpoints will be processed concurrently."""
479472 pm = spawn_procs_on_this_host (v1 , {})
@@ -610,9 +603,8 @@ def _handle_undeliverable_message(
610603 return True
611604
612605
613- @pytest .mark .parametrize ("v1" , [True , False ])
614606@pytest .mark .timeout (60 )
615- async def test_actor_log_streaming (v1 : bool ) -> None :
607+ async def test_actor_log_streaming () -> None :
616608 # Save original file descriptors
617609 original_stdout_fd = os .dup (1 ) # stdout
618610 original_stderr_fd = os .dup (2 ) # stderr
@@ -639,7 +631,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
639631 sys .stderr = stderr_file
640632
641633 try :
642- pm = spawn_procs_on_this_host (v1 , per_host = {"gpus" : 2 })
634+ pm = spawn_procs_on_this_host (v1 = False , per_host = {"gpus" : 2 })
643635 am = pm .spawn ("printer" , Printer )
644636
645637 # Disable streaming logs to client
@@ -679,10 +671,7 @@ async def test_actor_log_streaming(v1: bool) -> None:
679671 await am .print .call ("has print streaming too" )
680672 await am .log .call ("has log streaming as level matched" )
681673
682- if not v1 :
683- await pm .stop ()
684- else :
685- await asyncio .sleep (1 )
674+ await pm .stop ()
686675
687676 # Flush all outputs
688677 stdout_file .flush ()
@@ -763,9 +752,8 @@ async def test_actor_log_streaming(v1: bool) -> None:
763752 pass
764753
765754
766- @pytest .mark .parametrize ("v1" , [True , False ])
767755@pytest .mark .timeout (120 )
768- async def test_alloc_based_log_streaming (v1 : bool ) -> None :
756+ async def test_alloc_based_log_streaming () -> None :
769757 """Test both AllocHandle.stream_logs = False and True cases."""
770758
771759 async def test_stream_logs_case (stream_logs : bool , test_name : str ) -> None :
@@ -782,50 +770,23 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
782770
783771 try :
784772 # Create proc mesh with custom stream_logs setting
785- if not v1 :
786- host_mesh = create_local_host_mesh ()
787- alloc_handle = host_mesh ._alloc (hosts = 1 , gpus = 2 )
788-
789- # Override the stream_logs setting
790- custom_alloc_handle = AllocHandle (
791- alloc_handle ._hy_alloc , alloc_handle ._extent , stream_logs
792- )
773+ host_mesh = create_local_host_mesh ()
774+ alloc_handle = host_mesh ._alloc (hosts = 1 , gpus = 2 )
793775
794- pm = ProcMesh .from_alloc (custom_alloc_handle )
795- else :
796-
797- class ProcessAllocatorStreamLogs (ProcessAllocator ):
798- def allocate_nonblocking (
799- self , spec : AllocSpec
800- ) -> PythonTask [Alloc ]:
801- return super ().allocate_nonblocking (spec )
802-
803- def _stream_logs (self ) -> bool :
804- return stream_logs
805-
806- alloc = ProcessAllocatorStreamLogs (* _get_bootstrap_args ())
807-
808- host_mesh = HostMeshV1 .allocate_nonblocking (
809- "host" ,
810- Extent (["hosts" ], [1 ]),
811- alloc ,
812- bootstrap_cmd = _bootstrap_cmd (),
813- )
814-
815- pm = host_mesh .spawn_procs (name = "proc" , per_host = {"gpus" : 2 })
776+ # Override the stream_logs setting
777+ custom_alloc_handle = AllocHandle (
778+ alloc_handle ._hy_alloc , alloc_handle ._extent , stream_logs
779+ )
816780
781+ pm = ProcMesh .from_alloc (custom_alloc_handle )
817782 am = pm .spawn ("printer" , Printer )
818783
819784 await pm .initialized
820785
821786 for _ in range (5 ):
822787 await am .print .call (f"{ test_name } print streaming" )
823788
824- if not v1 :
825- await pm .stop ()
826- else :
827- # Wait for at least the aggregation window (3 seconds)
828- await asyncio .sleep (5 )
789+ await pm .stop ()
829790
830791 # Flush all outputs
831792 stdout_file .flush ()
@@ -849,18 +810,18 @@ def _stream_logs(self) -> bool:
849810 # When stream_logs=False, logs should not be streamed to client
850811 assert not re .search (
851812 rf"similar log lines.*{ test_name } print streaming" , stdout_content
852- ), f"stream_logs=False case: { stdout_content } "
813+ ), f"stream_logs=True case: { stdout_content } "
853814 assert re .search (
854815 rf"{ test_name } print streaming" , stdout_content
855- ), f"stream_logs=False case: { stdout_content } "
816+ ), f"stream_logs=True case: { stdout_content } "
856817 else :
857818 # When stream_logs=True, logs should be streamed to client (no aggregation by default)
858819 assert re .search (
859820 rf"similar log lines.*{ test_name } print streaming" , stdout_content
860- ), f"stream_logs=True case: { stdout_content } "
821+ ), f"stream_logs=False case: { stdout_content } "
861822 assert not re .search (
862823 rf"\[[0-9]\]{ test_name } print streaming" , stdout_content
863- ), f"stream_logs=True case: { stdout_content } "
824+ ), f"stream_logs=False case: { stdout_content } "
864825
865826 finally :
866827 # Ensure file descriptors are restored even if something goes wrong
@@ -875,9 +836,8 @@ def _stream_logs(self) -> bool:
875836 await test_stream_logs_case (True , "stream_logs_true" )
876837
877838
878- @pytest .mark .parametrize ("v1" , [True , False ])
879839@pytest .mark .timeout (60 )
880- async def test_logging_option_defaults (v1 : bool ) -> None :
840+ async def test_logging_option_defaults () -> None :
881841 # Save original file descriptors
882842 original_stdout_fd = os .dup (1 ) # stdout
883843 original_stderr_fd = os .dup (2 ) # stderr
@@ -904,18 +864,14 @@ async def test_logging_option_defaults(v1: bool) -> None:
904864 sys .stderr = stderr_file
905865
906866 try :
907- pm = spawn_procs_on_this_host (v1 , per_host = {"gpus" : 2 })
867+ pm = spawn_procs_on_this_host (v1 = False , per_host = {"gpus" : 2 })
908868 am = pm .spawn ("printer" , Printer )
909869
910870 for _ in range (5 ):
911871 await am .print .call ("print streaming" )
912872 await am .log .call ("log streaming" )
913873
914- if not v1 :
915- await pm .stop ()
916- else :
917- # Wait for > default aggregation window (3 seconds)
918- await asyncio .sleep (5 )
874+ await pm .stop ()
919875
920876 # Flush all outputs
921877 stdout_file .flush ()
@@ -993,8 +949,7 @@ def __init__(self):
993949
994950# oss_skip: pytest keeps complaining about mocking get_ipython module
995951@pytest .mark .oss_skip
996- @pytest .mark .parametrize ("v1" , [True , False ])
997- async def test_flush_called_only_once (v1 : bool ) -> None :
952+ async def test_flush_called_only_once () -> None :
998953 """Test that flush is called only once when ending an ipython cell"""
999954 mock_ipython = MockIPython ()
1000955 with unittest .mock .patch (
@@ -1006,8 +961,8 @@ async def test_flush_called_only_once(v1: bool) -> None:
1006961 "monarch._src.actor.logging.flush_all_proc_mesh_logs"
1007962 ) as mock_flush :
1008963 # Create 2 proc meshes with a large aggregation window
1009- pm1 = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1010- _ = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
964+ pm1 = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
965+ _ = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
1011966 # flush not yet called unless post_run_cell
1012967 assert mock_flush .call_count == 0
1013968 assert mock_ipython .events .registers == 0
@@ -1021,9 +976,8 @@ async def test_flush_called_only_once(v1: bool) -> None:
1021976
1022977# oss_skip: pytest keeps complaining about mocking get_ipython module
1023978@pytest .mark .oss_skip
1024- @pytest .mark .parametrize ("v1" , [True , False ])
1025979@pytest .mark .timeout (180 )
1026- async def test_flush_logs_ipython (v1 : bool ) -> None :
980+ async def test_flush_logs_ipython () -> None :
1027981 """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
1028982 # Save original file descriptors
1029983 original_stdout_fd = os .dup (1 ) # stdout
@@ -1049,8 +1003,8 @@ async def test_flush_logs_ipython(v1: bool) -> None:
10491003 ), unittest .mock .patch ("monarch._src.actor.logging.IN_IPYTHON" , True ):
10501004 # Make sure we can register and unregister callbacks
10511005 for _ in range (3 ):
1052- pm1 = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1053- pm2 = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1006+ pm1 = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
1007+ pm2 = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
10541008 am1 = pm1 .spawn ("printer" , Printer )
10551009 am2 = pm2 .spawn ("printer" , Printer )
10561010
@@ -1154,9 +1108,8 @@ async def test_flush_logs_fast_exit() -> None:
11541108 ), process .stdout
11551109
11561110
1157- @pytest .mark .parametrize ("v1" , [True , False ])
11581111@pytest .mark .timeout (60 )
1159- async def test_flush_on_disable_aggregation (v1 : bool ) -> None :
1112+ async def test_flush_on_disable_aggregation () -> None :
11601113 """Test that logs are flushed when disabling aggregation.
11611114
11621115 This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
@@ -1177,7 +1130,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
11771130 sys .stdout = stdout_file
11781131
11791132 try :
1180- pm = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1133+ pm = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
11811134 am = pm .spawn ("printer" , Printer )
11821135
11831136 # Set a long aggregation window to ensure logs aren't flushed immediately
@@ -1198,11 +1151,7 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
11981151 for _ in range (5 ):
11991152 await am .print .call ("single log line" )
12001153
1201- if not v1 :
1202- await pm .stop ()
1203- else :
1204- # Wait for > default aggregation window (3 secs)
1205- await asyncio .sleep (5 )
1154+ await pm .stop ()
12061155
12071156 # Flush all outputs
12081157 stdout_file .flush ()
@@ -1248,15 +1197,14 @@ async def test_flush_on_disable_aggregation(v1: bool) -> None:
12481197 pass
12491198
12501199
1251- @pytest .mark .parametrize ("v1" , [True , False ])
12521200@pytest .mark .timeout (120 )
1253- async def test_multiple_ongoing_flushes_no_deadlock (v1 : bool ) -> None :
1201+ async def test_multiple_ongoing_flushes_no_deadlock () -> None :
12541202 """
12551203 The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
12561204 Because now a flush call is purely sync, it is very easy to get into a deadlock.
12571205 So we assert the last flush call will not get into such a state.
12581206 """
1259- pm = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 4 })
1207+ pm = this_host (). spawn_procs ( per_host = {"gpus" : 4 })
12601208 am = pm .spawn ("printer" , Printer )
12611209
12621210 # Generate some logs that will be aggregated but not flushed immediately
@@ -1279,9 +1227,8 @@ async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
12791227 futures [- 1 ].get ()
12801228
12811229
1282- @pytest .mark .parametrize ("v1" , [True , False ])
12831230@pytest .mark .timeout (60 )
1284- async def test_adjust_aggregation_window (v1 : bool ) -> None :
1231+ async def test_adjust_aggregation_window () -> None :
12851232 """Test that the flush deadline is updated when the aggregation window is adjusted.
12861233
12871234 This tests the corner case: "This can happen if the user has adjusted the aggregation window."
@@ -1302,7 +1249,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
13021249 sys .stdout = stdout_file
13031250
13041251 try :
1305- pm = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1252+ pm = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
13061253 am = pm .spawn ("printer" , Printer )
13071254
13081255 # Set a long aggregation window initially
@@ -1320,11 +1267,7 @@ async def test_adjust_aggregation_window(v1: bool) -> None:
13201267 for _ in range (3 ):
13211268 await am .print .call ("second batch of logs" )
13221269
1323- if not v1 :
1324- await pm .stop ()
1325- else :
1326- # Wait for > aggregation window (2 secs)
1327- await asyncio .sleep (4 )
1270+ await pm .stop ()
13281271
13291272 # Flush all outputs
13301273 stdout_file .flush ()
0 commit comments