70 changes: 40 additions & 30 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -204,6 +204,34 @@ class DataResponder::Impl
}
}

void sendResponse(std::vector<size_t> const& blockHashes, std::map<RequestIdType, Response>::iterator it)
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
{
mRemainSendCount.erase(reqId);

// TODO(zhengd): pass the hashes directly instead of update llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));

if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
}
removeResponse(it);
}
mCurrentRequest = std::nullopt;
}

void response() noexcept
{
try
@@ -237,40 +265,22 @@ class DataResponder::Impl
auto it = getCurrentResponse();
if (it != mReadyResponses.end())
{
auto reqId = mCurrentRequest.value();
auto count = --mRemainSendCount[reqId];
TLLM_CHECK(count >= 0);
if (count == 0)
sendResponse(blockHashes, it);
}
else
{
auto it = getCurrentResponse();
while (it == mReadyResponses.end())
{
mRemainSendCount.erase(reqId);

// TODO(zhengd): pass the hashes directly instead of update llmRequest
auto llmRequest = it->second.mRequest;
llmRequest->setRequestedBlockHashes(std::move(blockHashes));

if (common::getEnvParallelCacheSend())
{
// TODO: Use a thread pool and check for thread safety.
std::thread(
&DataResponder::Impl::sendAndRemoveResponse, this, it->first, std::move(it->second))
.detach();
}
else
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
if (mTerminate)
{
DataResponder::Impl::sendAndRemoveResponse(it->first, std::move(it->second));
break;
}
removeResponse(it);
it = getCurrentResponse();
}
mCurrentRequest = std::nullopt;
}
else
{
TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
"This executor does not have a prepared KV cache for request ID: %zu, and the "
"mReadyResponses size is: %zu. mpi rank :%d ",
mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
std::unique_lock lk(mCondMutex);
mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
sendResponse(blockHashes, it);
}
}
}
137 changes: 135 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -61,6 +61,9 @@
# Set to a path to save detailed tracing of PyTorch operations.
PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"

# Unique tag base to avoid collisions with token/logits comms
TERMINATION_COMM_TAG_BASE = 20000


@functools.cache
def _load_iteration_indexes(env_var: str):
@@ -208,6 +211,7 @@ def __init__(self,
self.kv_cache_manager = self.resource_manager.resource_managers.get(
ResourceManagerType.KV_CACHE_MANAGER)
self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0
self.enable_kv_cache_reuse = self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse

self.max_input_len = max_input_len
# _executor_loop private data
@@ -259,6 +263,13 @@ def __init__(self,
self.gather_all_responses = False

self.kv_cache_transceiver = kv_cache_transceiver

# Initialize disagg PP termination handler if needed
self._disagg_pp_termination_handler = None
if self.dist.pp_size > 1 and self.enable_kv_cache_reuse and self.kv_cache_transceiver:
self._disagg_pp_termination_handler = DisaggPPTerminationHandler(
self.num_micro_batches, self.dist)

if self.dist.pp_size > 1:
self.event_loop = self._executor_loop_pp
else:
@@ -718,6 +729,14 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest],
batch_state.sample_state.scheduled_requests), req_stats)

def _executor_loop_cleanup(self):

for h in self.send_handles:
if h is not None:
h.wait()

if self._disagg_pp_termination_handler is not None:
self._disagg_pp_termination_handler.cleanup()

with self.response_cv:
self.is_shutdown = True
self.response_cv.notify_all()
@@ -826,6 +845,7 @@ def _executor_loop_pp(self):

sample_state = self._sample_async(
scheduled_batch, batch_outputs)
assert sample_state is not None, "Sampling failed"
self._update_request_states(scheduled_batch)

if self.enable_iter_perf_stats:
@@ -905,6 +925,12 @@ def _executor_loop_pp(self):
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()

if self._disagg_pp_termination_handler is not None:
requests_to_terminate = self._disagg_pp_termination_handler.sync(
prev_microbatch_id)
for req in requests_to_terminate:
self._do_terminate_request(req)

# march forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches

@@ -1696,9 +1722,13 @@ def _handle_errors(self,
self._enqueue_responses(error_responses.items())

def _terminate_request(self, request: LlmRequest):
if self.kv_connector_manager is None:
self.resource_manager.free_resources(request)
if self._disagg_pp_termination_handler is not None:
self._disagg_pp_termination_handler.terminate(request)
else:
self._do_terminate_request(request)

def _do_terminate_request(self, request: LlmRequest):
if self.kv_connector_manager is not None:
# Only call request_finished on the connector if the request has already been added to the kv cache manager.
try:
cache_block_ids = self.kv_cache_manager.get_cache_indices(
@@ -1711,6 +1741,8 @@ def _terminate_request(self, request: LlmRequest):
if not self.kv_connector_manager.request_finished(
request, cache_block_ids):
self.resource_manager.free_resources(request)
else:
self.resource_manager.free_resources(request)

@nvtx_range("_handle_canceled_requests")
def _handle_canceled_requests(self):
@@ -1919,3 +1951,104 @@ def _remove_inflight_ids(self, scheduled_requests):
"""Remove reqids of current requests from self.inflight_req_ids."""
for req in scheduled_requests.all_requests():
self.inflight_req_ids.erase(req.request_id)


class DisaggPPTerminationHandler:
"""Handles termination synchronization across pipeline parallel ranks under disaggregated serving.

Synchronization is required when terminating requests in disaggregated PP with
KV cache reuse enabled: all PP ranks must reach consensus before freeing
resources, otherwise termination can trigger a NCCL hang.
"""

def __init__(self, num_micro_batches: int, dist):
self.dist = dist
# Request termination synchronization across PP ranks
# {request_id: {'ready_to_terminate': set{ranks}, 'terminated': {ranks}}}
self.pending_termination = {}
self.termination_handles = [None] * num_micro_batches
# Local map from request_id -> local LlmRequest awaiting consensus termination
self.local_termination = {}

def terminate(self, request: LlmRequest) -> bool:
req_key = request.py_request_id
self.local_termination[req_key] = request
state = self.pending_termination.get(req_key, None)
if state is None:
state = {'ready_to_terminate': set(), 'terminated': set()}
self.pending_termination[req_key] = state
if self.dist.rank not in state['ready_to_terminate']:
state['ready_to_terminate'].add(self.dist.rank)
return False

def sync(self, microbatch_id: int) -> List[LlmRequest]:
"""Ring-communicate pending termination state and apply local terminations upon consensus.

Each rank sends its current pending_termination snapshot to the next PP rank
and receives the previous rank's snapshot. After merging, apply any terminations
that have reached consensus (i.e., all PP ranks are ready).
"""
snapshot = {
req_id: {
'ready_to_terminate': state.get('ready_to_terminate', set()),
'terminated': state.get('terminated', set()),
}
for req_id, state in self.pending_termination.items()
}

if self.termination_handles[microbatch_id] is not None:
self.termination_handles[microbatch_id].wait()

term_tag = TERMINATION_COMM_TAG_BASE + microbatch_id
self.termination_handles[microbatch_id] = self.dist.isend_object(
snapshot,
dest=self.dist.next_pp_rank,
tag=term_tag,
)
remote_state = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=term_tag,
)
logger.debug(
f"received remote state for microbatch {microbatch_id}, prev pp rank: {self.dist.prev_pp_rank} state {remote_state}"
)

if remote_state:
for req_id, state in remote_state.items():
local = self.pending_termination.get(req_id)
if local is None:
self.pending_termination[req_id] = {
'ready_to_terminate': state.get('ready_to_terminate',
set()),
'terminated': state.get('terminated', set()),
}
else:
for key in ('ready_to_terminate', 'terminated'):
for r in state.get(key, []):
if r not in local[key]:
local[key].add(r)

requests_to_terminate = []
to_delete = []
for req_id, state in self.pending_termination.items():
ready = state.get('ready_to_terminate', set())
done = state.get('terminated', set())
# If all PP ranks are ready to terminate the request, we can free the resources
if len(ready) >= self.dist.pp_size and self.dist.rank not in done:
local_req = self.local_termination.get(req_id)
if local_req is not None:
requests_to_terminate.append(local_req)
done.add(self.dist.rank)
if len(done) >= self.dist.pp_size:
to_delete.append(req_id)
if req_id in self.local_termination:
self.local_termination.pop(req_id, None)
for req_id in to_delete:
self.pending_termination.pop(req_id, None)

return requests_to_terminate

def cleanup(self):
for h in self.termination_handles:
if h is not None:
h.wait()
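
For reference, the consensus rule that DisaggPPTerminationHandler follows can be reduced to a short, standalone simulation. The sketch below is not part of the PR: it replaces the real isend_object/recv_object ring exchange with in-process set merges, and names such as PP_SIZE, states and freed are illustrative only. It shows the one property the handler relies on: a rank frees its resources only once every rank appears in the merged ready_to_terminate set.

# Standalone sketch (not from the PR): simulate the per-microbatch ring consensus
# used by DisaggPPTerminationHandler for a single request. Illustrative names only.

PP_SIZE = 4  # assumed pipeline-parallel world size

# Per-rank view of the request's termination state, as in pending_termination.
states = [{"ready_to_terminate": set(), "terminated": set()} for _ in range(PP_SIZE)]
freed = [False] * PP_SIZE  # whether each rank has freed its local resources

# Every rank marks itself ready, mirroring DisaggPPTerminationHandler.terminate().
for rank in range(PP_SIZE):
    states[rank]["ready_to_terminate"].add(rank)

# Repeat the per-microbatch sync() until every rank has freed the request.
rounds = 0
while not all(freed):
    rounds += 1
    # Snapshot first, then let each rank merge the previous rank's snapshot,
    # which models isend_object to next_pp_rank / recv_object from prev_pp_rank.
    snapshots = [{k: set(v) for k, v in s.items()} for s in states]
    for rank in range(PP_SIZE):
        prev = (rank - 1) % PP_SIZE
        for key in ("ready_to_terminate", "terminated"):
            states[rank][key] |= snapshots[prev][key]
        # Consensus: all ranks are ready and this rank has not freed yet.
        if len(states[rank]["ready_to_terminate"]) >= PP_SIZE and rank not in states[rank]["terminated"]:
            freed[rank] = True
            states[rank]["terminated"].add(rank)

print(f"all ranks freed after {rounds} sync rounds")  # 3 rounds for PP_SIZE = 4

Because each sync() call moves information one hop around the ring, consensus on a request should be reached within at most pp_size - 1 sync calls after the last rank marks it ready, which is why the handler can piggyback on the existing per-microbatch cadence of the PP executor loop.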
3 changes: 0 additions & 3 deletions tests/integration/test_lists/waives.txt
@@ -252,11 +252,9 @@ examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-re
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5421989)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5421989)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5433545)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-27b-it] SKIP (https://nvbugs/5434451)
examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-3-1b-it] SKIP (https://nvbugs/5434451)
@@ -278,7 +276,6 @@ triton_server/test_triton.py::test_gpt_ib_ptuning[gpt-ib-ptuning] SKIP (https://
triton_server/test_triton.py::test_mistral_ib_mm[mistral-ib-mm] SKIP (https://nvbugs/5371343)
triton_server/test_triton.py::test_t5_ib[t5-ib] SKIP (https://nvbugs/5456482)
triton_server/test_triton_llm.py::test_gpt_speculative_decoding_bls[False-False-1---False-True-True-0-128-disableDecoupleMode-inflight_fused_batching-disableTrtOverlap-0.2-guaranteed_no_evict---1-1-1-False-ensemble] SKIP (https://nvbugs/5456485)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_eagle3] SKIP (https://nvbugs/5437384)
llmapi/test_llm_examples.py::test_llmapi_speculative_decoding_mtp SKIP (https://nvbugs/5461796)
accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_gather_generation_logits_cuda_graph SKIP (https://nvbugs/5365525)