Skip to content

Commit 0afea26

Browse files
committed
disagg-specific check
Signed-off-by: raayandhar <[email protected]>
1 parent 0b98591 commit 0afea26

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -874,8 +874,8 @@ def _executor_loop_pp(self):
874874

875875
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
876876
self._terminate_ctx_finished_requests()
877-
878-
if self.dist.pp_size > 1 and self.enable_kv_cache_reuse:
877+
878+
if self.dist.pp_size > 1 and self.enable_kv_cache_reuse and self.kv_cache_transceiver:
879879
self._sync_termination(prev_microbatch_id)
880880

881881
# march forward in microbatch slots
@@ -1598,7 +1598,7 @@ def _handle_errors(self,
15981598
self._enqueue_responses(error_responses.items())
15991599

16001600
def _terminate_request(self, request: LlmRequest):
1601-
if self.dist.pp_size > 1 and self.enable_kv_cache_reuse:
1601+
if self.dist.pp_size > 1 and self.enable_kv_cache_reuse and self.kv_cache_transceiver:
16021602
# If pp_size > 1 and enable_kv_cache_reuse, we need to sync termination across PP ranks
16031603
# otherwise, different ranks may have different KV cache blocks and a request may
16041604
# have different PrepopulatedPromptLen which leads to a NCCL hang.
@@ -1612,7 +1612,7 @@ def _terminate_request(self, request: LlmRequest):
16121612
state['ready_to_terminate'].add(self.dist.rank)
16131613
else:
16141614
self._free_resources_for_request(request)
1615-
1615+
16161616
def _free_resources_for_request(self, request: LlmRequest):
16171617
logger.debug(f"free resources for request {request.py_request_id}")
16181618
self.resource_manager.free_resources(request)
@@ -1773,17 +1773,18 @@ def _sync_termination(self, microbatch_id: int):
17731773
src=self.dist.prev_pp_rank,
17741774
tag=microbatch_id,
17751775
)
1776-
logger.debug(f"received remote state for microbatch {microbatch_id}, prev pp rank: {self.dist.prev_pp_rank} state {remote_state}")
1776+
logger.debug(
1777+
f"received remote state for microbatch {microbatch_id}, prev pp rank: {self.dist.prev_pp_rank} state {remote_state}"
1778+
)
17771779

17781780
if remote_state:
17791781
for req_id, state in remote_state.items():
17801782
local = self.pending_termination.get(req_id)
17811783
if local is None:
17821784
self.pending_termination[req_id] = {
1783-
'ready_to_terminate':
1784-
state.get('ready_to_terminate', set()),
1785-
'terminated':
1786-
state.get('terminated', set()),
1785+
'ready_to_terminate': state.get('ready_to_terminate',
1786+
set()),
1787+
'terminated': state.get('terminated', set()),
17871788
}
17881789
else:
17891790
for key in ('ready_to_terminate', 'terminated'):

0 commit comments

Comments
 (0)