Skip to content

Commit 6affc99

Browse files
committed
Finish all send requests before quitting the PP event loop to avoid an MPI deadlock; synchronize the sampler right after async calls to avoid a hang
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 96bda14 commit 6affc99

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,12 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
627627
batch_state.sample_state.scheduled_requests), req_stats)
628628

629629
def _executor_loop_cleanup(self):
630+
# Unblock receiving processes. When the second-to-last rank quits before the last rank,
631+
# the last rank will never return from recv_object.
632+
for req in self.send_handles:
633+
if req is not None:
634+
req.wait()
635+
630636
with self.response_cv:
631637
self.is_shutdown = True
632638
self.response_cv.notify_all()
@@ -750,8 +756,10 @@ def _executor_loop_pp(self):
750756

751757
sample_state = self._sample_async(
752758
scheduled_batch, batch_outputs)
759+
assert sample_state is not None, "Sampling failed"
753760
sample_state.host.logits = logits_host
754761
self._update_request_states(scheduled_batch)
762+
sample_state.sampler_event.synchronize()
755763

756764
if self.enable_iter_perf_stats:
757765
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,22 @@ def run_disaggregated_test(example_dir,
156156
run_env = env.copy()
157157
run_env["UCX_TLS"] = "^ib"
158158

159+
nsys_path = os.getenv("NSYS_PATH", None)
160+
nsys_file = os.getenv("NSYS_FILE", None)
161+
nsys_cmd = [
162+
"nsys",
163+
"profile",
164+
"--trace",
165+
"cuda,cublas,nvtx",
166+
"--output",
167+
nsys_file,
168+
"--force-overwrite=true",
169+
] if nsys_path and nsys_file else []
170+
159171
num_ranks, config_file = get_test_config(test_desc, example_dir,
160172
os.path.dirname(__file__))
161173

162-
workers_cmd = [
174+
workers_cmd = nsys_cmd + [
163175
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
164176
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
165177
config_file

0 commit comments

Comments
 (0)