diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index e70536e0199..d0a4fd59b26 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -17,7 +17,12 @@ get_logprob_from_pp_outputs, ) from sglang.srt.model_executor.forward_batch_info import PPProxyTensors -from sglang.srt.utils import DynamicGradMode, broadcast_pyobj, point_to_point_pyobj, require_mlp_sync +from sglang.srt.utils import ( + DynamicGradMode, + broadcast_pyobj, + point_to_point_pyobj, + require_mlp_sync, +) logger = logging.getLogger(__name__) @@ -183,23 +188,13 @@ def _pp_send_output_to_next_stage( def _pp_send_recv_and_preprocess_output_tensors( self: Scheduler, - next_first_rank_mb_id: int, next_mb_id: int, mbs: List[ScheduleBatch], mb_metadata: List[PPBatchMetadata], - last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], - pp_outputs: PPProxyTensors | None, ) -> Tuple[PPProxyTensors, List[P2PWork], torch.cuda.Event]: next_pp_outputs = None d2h_event = None batch_result = None - send_output_work = self._pp_send_output_to_next_stage( - next_first_rank_mb_id, - mbs, - last_rank_comm_queue, - pp_outputs, - ) - if mbs[next_mb_id] is not None: with torch.profiler.record_function("recv_res_dict_from_prev_stage"): next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) @@ -211,7 +206,7 @@ def _pp_send_recv_and_preprocess_output_tensors( d2h_event = torch.cuda.Event() d2h_event.record(torch.cuda.current_stream()) - return next_pp_outputs, batch_result, d2h_event, send_output_work + return next_pp_outputs, batch_result, d2h_event def _pp_launch_batch( self: Scheduler, @@ -301,49 +296,45 @@ def event_loop_pp(self: Scheduler): mbs[mb_id] = self.get_next_batch_to_run() self.running_mbs[mb_id] = self.running_batch self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] - if self.cur_batch: - server_is_idle = False - pp_proxy_tensors = self._pp_recv_proxy_tensors() - next_pp_outputs = None - next_batch_result = None - d2h_event = None - if self.server_args.pp_async_batch_depth > 0: + if ( + self.server_args.pp_async_batch_depth > 0 + or not self.pp_group.is_last_rank + ): self._pp_commit_comm_work(work=send_output_work) - next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( - self._pp_send_recv_and_preprocess_output_tensors( - next_first_rank_mb_id, - next_mb_id, - mbs, - mb_metadata, - last_rank_comm_queue, - pp_outputs, - ) + send_output_work = self._pp_send_output_to_next_stage( + next_first_rank_mb_id, + mbs, + last_rank_comm_queue, + pp_outputs, ) - self._pp_commit_comm_work(send_proxy_work) if self.cur_batch: + server_is_idle = False + pp_proxy_tensors = self._pp_recv_proxy_tensors() result, event = self._pp_launch_batch( mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue ) - if self.server_args.pp_async_batch_depth == 0: + if ( + self.server_args.pp_async_batch_depth == 0 + and self.pp_group.is_last_rank + ): self._pp_commit_comm_work(work=send_output_work) - next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( - self._pp_send_recv_and_preprocess_output_tensors( - next_first_rank_mb_id, - next_mb_id, - mbs, - mb_metadata, - last_rank_comm_queue, - pp_outputs, - ) + send_output_work = self._pp_send_output_to_next_stage( + next_first_rank_mb_id, + mbs, + last_rank_comm_queue, + pp_outputs, ) - if mbs[next_mb_id] is not None: - d2h_event.synchronize() - with torch.profiler.record_function("process_batch_result"): - self._pp_process_batch_result( - mbs[next_mb_id], - next_batch_result, - ) - last_mbs[next_mb_id] = mbs[next_mb_id] + next_pp_outputs = None + next_batch_result = None + d2h_event = None + next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_mb_id, + mbs, + mb_metadata, + ) + ) + self._pp_commit_comm_work(send_proxy_work) if not self.pp_group.is_last_rank: if self.cur_batch: torch.cuda.current_stream().wait_event(event) @@ -354,6 +345,14 @@ def event_loop_pp(self: Scheduler): result.pp_hidden_states_proxy_tensors.tensors, async_send=True, ) + if mbs[next_mb_id] is not None: + d2h_event.synchronize() + with torch.profiler.record_function("process_batch_result"): + self._pp_process_batch_result( + mbs[next_mb_id], + next_batch_result, + ) + last_mbs[next_mb_id] = mbs[next_mb_id] # if self.delayed_weight_sync_fn: # self.delayed_weight_sync_fn()