diff --git a/cpp/tensorrt_llm/executor/executorImpl.cpp b/cpp/tensorrt_llm/executor/executorImpl.cpp
index 6638992b452..664940dce69 100644
--- a/cpp/tensorrt_llm/executor/executorImpl.cpp
+++ b/cpp/tensorrt_llm/executor/executorImpl.cpp
@@ -947,30 +947,32 @@ std::vector<Response> Executor::Impl::awaitResponses(std::optional<std::chrono::
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses]() -> bool { return !resp.empty() || mShutdown; };
-    auto storeResponses = [this, &resp = this->mResponses, &responses]()
+    auto pred = [this]() -> bool { return !mResponses.empty() || mShutdown; };
+    auto storeResponses = [this]()
     {
-        for (auto it = resp.cbegin(); it != resp.cend();)
+        std::vector<Response> responses;
+        for (auto it = mResponses.begin(); it != mResponses.end();)
         {
             responses.insert(responses.end(), it->second.begin(), it->second.end());
             addTerminatedReqId(it->second, it->first);
-            resp.erase(it++);
+            it = mResponses.erase(it);
         }
+        return responses;
     };
 
+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeResponses();
+            responses = storeResponses();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeResponses();
+        responses = storeResponses();
     }
     return responses;
 }
@@ -980,15 +982,16 @@ std::vector<Response> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses, reqId]() -> bool
-    { return (resp.find(reqId) != resp.end() && !resp.at(reqId).empty()) || mShutdown; };
-    auto storeIdResponse = [this, &resp = this->mResponses, &responses, reqId]()
+    auto pred = [this, reqId]() -> bool
+    { return (mResponses.find(reqId) != mResponses.end() && !mResponses.at(reqId).empty()) || mShutdown; };
+    auto storeIdResponse = [this, reqId]()
     {
-        responses.swap(resp.at(reqId));
-        resp.erase(reqId);
+        std::vector<Response> responses;
+        responses.swap(mResponses.at(reqId));
+        mResponses.erase(reqId);
         addTerminatedReqId(responses, reqId);
+        return responses;
     };
 
     // We don't process a terminated request again. Terminated request is defined as a response
@@ -1005,17 +1008,18 @@
         return {Response(reqId, err)};
     }
 
+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeIdResponse();
+            responses = storeIdResponse();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeIdResponse();
+        responses = storeIdResponse();
     }
     return responses;
 }
@@ -1025,26 +1029,27 @@ std::vector<std::vector<Response>> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<std::vector<Response>> v(requestIds.size());
+    std::vector<std::vector<Response>> responses;
+    responses.reserve(requestIds.size());
     if (timeout)
     {
         auto const start_time = std::chrono::high_resolution_clock::now();
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
         {
             auto const elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                 std::chrono::high_resolution_clock::now() - start_time);
-            v[i] = awaitResponses(requestIds[i],
-                timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0});
+            responses.emplace_back(awaitResponses(
+                requestId, timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0}));
         }
     }
     else
     {
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
         {
-            v[i] = awaitResponses(requestIds[i]);
+            responses.emplace_back(awaitResponses(requestId));
         }
     }
-    return v;
+    return responses;
 }
 
 SizeType32 Executor::Impl::getNumResponsesReady(std::optional<IdType> const& optId) const
@@ -1663,7 +1668,7 @@ void Executor::Impl::terminateActiveRequests(RequestList& activeRequests, std::s
         }
 
         // Remove from the requestList
-        activeRequests.erase(it++);
+        it = activeRequests.erase(it);
     }
 }
 
diff --git a/cpp/tensorrt_llm/executor/executorImpl.h b/cpp/tensorrt_llm/executor/executorImpl.h
index 8ab2a35a680..7d34cbdf382 100644
--- a/cpp/tensorrt_llm/executor/executorImpl.h
+++ b/cpp/tensorrt_llm/executor/executorImpl.h
@@ -107,7 +107,7 @@ class Executor::Impl
     std::vector<Response> awaitResponses(std::optional<std::chrono::milliseconds> const& timeout = std::nullopt);
 
     std::vector<Response> awaitResponses(
-        IdType const& optId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
+        IdType const& reqId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
 
     std::vector<std::vector<Response>> awaitResponses(
         std::vector<IdType> const& requestIds, std::optional<std::chrono::milliseconds> const& timeout);
diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py
index 0ef28ffcc59..fed95108cb2 100644
--- a/tensorrt_llm/runtime/model_runner_cpp.py
+++ b/tensorrt_llm/runtime/model_runner_cpp.py
@@ -933,13 +933,17 @@ def _initialize_and_fill_output(
         output_ids = [[[] for _ in range(num_sequences)]
                       for _ in range(len(request_ids))]
 
-        multi_responses = self.session.await_responses(request_ids)
-        responses = [
-            response for responses in multi_responses for response in responses
-        ]
+        all_responses = []
+        finished_request_ids = set()
+        while finished_request_ids != set(request_ids):
+            responses = self.session.await_responses()
+            for response in responses:
+                if response.result.is_final:
+                    finished_request_ids.add(response.request_id)
+            all_responses.extend(responses)
 
         return self._fill_output(
-            responses=responses,
+            responses=all_responses,
             output_ids=output_ids,
             end_id=end_id,
             return_dict=return_dict,
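Note (not part of the patch): the C++ hunks above apply two recurring idioms. The collection helpers (storeResponses, storeIdResponse) now return their result by value instead of writing through captured references, and container erasure uses "it = container.erase(it)" rather than "container.erase(it++)". The standalone sketch below illustrates both under simplified assumptions; awaitAll, gResponses, and the Response alias are illustrative stand-ins, not TensorRT-LLM API.

// Standalone illustration only; simplified stand-ins for the executor's members.
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <map>
#include <mutex>
#include <optional>
#include <string>
#include <vector>

using IdType = std::uint64_t;
using Response = std::string; // stand-in for tensorrt_llm::executor::Response
using ResponseMap = std::map<IdType, std::vector<Response>>;

std::mutex gMtx;             // plays the role of mResponsesMtx
std::condition_variable gCv; // plays the role of mResponsesCv
ResponseMap gResponses;      // plays the role of mResponses
bool gShutdown = false;      // plays the role of mShutdown

// Drains every pending response, mirroring the reworked awaitResponses():
// the helper lambda returns the drained vector instead of appending through
// a captured reference, so the result is assembled while the lock is held
// and its ownership is explicit at the call site.
std::vector<Response> awaitAll(std::optional<std::chrono::milliseconds> timeout)
{
    std::unique_lock<std::mutex> lck(gMtx);
    auto pred = [] { return !gResponses.empty() || gShutdown; };
    auto drain = []
    {
        std::vector<Response> out;
        for (auto it = gResponses.begin(); it != gResponses.end();)
        {
            out.insert(out.end(), it->second.begin(), it->second.end());
            // erase() hands back the next valid iterator; this replaces the
            // older gResponses.erase(it++) form with the clearer idiom.
            it = gResponses.erase(it);
        }
        return out;
    };

    std::vector<Response> out;
    if (timeout)
    {
        if (gCv.wait_for(lck, *timeout, pred))
        {
            out = drain();
        }
    }
    else
    {
        gCv.wait(lck, pred);
        out = drain();
    }
    return out;
}

int main()
{
    {
        std::lock_guard<std::mutex> lck(gMtx);
        gResponses[1] = {"token-a", "token-b"};
        gResponses[2] = {"token-c"};
    }
    gCv.notify_all();
    for (auto const& r : awaitAll(std::chrono::milliseconds{100}))
    {
        std::cout << r << '\n';
    }
    return 0;
}

Returning by value keeps ownership of the drained vector obvious and lets a single helper serve both the timed and untimed wait paths, which is exactly how the patch reuses storeResponses and storeIdResponse.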