diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 7d9b247c945..55173f52426 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -627,6 +627,12 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
                 batch_state.sample_state.scheduled_requests), req_stats)
 
     def _executor_loop_cleanup(self):
+        # Unblock receiving processes. When the second-last rank quits before
+        # the last rank, the last rank will never return from recv_object.
+        for req in self.send_handles:
+            if req is not None:
+                req.wait()
+
         with self.response_cv:
             self.is_shutdown = True
             self.response_cv.notify_all()
@@ -750,6 +756,7 @@ def _executor_loop_pp(self):
 
                         sample_state = self._sample_async(
                             scheduled_batch, batch_outputs)
+                        assert sample_state is not None, "Sampling failed"
                         sample_state.host.logits = logits_host
 
                     self._update_request_states(scheduled_batch)
@@ -801,6 +808,7 @@ def _executor_loop_pp(self):
                 if not self.dist.is_second_last_pp_rank:
                     if self.send_handles[prev_microbatch_id] is not None:
                         self.send_handles[prev_microbatch_id].wait()
+                    self.send_handles[prev_microbatch_id] = None
                     needs_logits = (
                         self._need_return_logits(scheduled_batch)
                         or (self._need_return_log_probs(scheduled_batch)
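
For context on the first hunk: the cleanup path now drains any in-flight pipeline-parallel sends before signalling shutdown, so a downstream rank blocked in recv_object is guaranteed to receive the final message rather than hanging. Below is a minimal sketch of the failure mode and the fix, written against mpi4py's isend/recv rather than TensorRT-LLM's internal dist wrapper; the ranks, tag, and payload are hypothetical stand-ins, not the executor's actual wire format.

```python
# Hypothetical repro of the shutdown hang, assuming mpi4py semantics; the real
# code goes through TensorRT-LLM's dist layer, not raw MPI.
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

send_handles = [None]  # one slot per in-flight microbatch, as in py_executor

if rank == 0:  # stand-in for the second-last pipeline rank
    send_handles[0] = comm.isend({"microbatch": 0}, dest=1, tag=0)
    # Cleanup analogous to the patched _executor_loop_cleanup: drain pending
    # sends before exiting; otherwise rank 1 can block in recv() forever,
    # since exiting with an incomplete isend is erroneous in MPI.
    for req in send_handles:
        if req is not None:
            req.wait()
elif rank == 1:  # stand-in for the last pipeline rank, blocked like recv_object
    msg = comm.recv(source=0, tag=0)
```

Run with, e.g., `mpirun -n 2 python repro.py`. The third hunk complements this: resetting `self.send_handles[prev_microbatch_id]` to None once a send has been waited on keeps `_executor_loop_cleanup` from waiting again on handles that already completed during the main loop.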