51 changes: 28 additions & 23 deletions cpp/tensorrt_llm/executor/executorImpl.cpp
@@ -947,30 +947,32 @@ std::vector<Response> Executor::Impl::awaitResponses(std::optional<std::chrono::
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses]() -> bool { return !resp.empty() || mShutdown; };
-    auto storeResponses = [this, &resp = this->mResponses, &responses]()
+    auto pred = [this]() -> bool { return !mResponses.empty() || mShutdown; };
+    auto storeResponses = [this]()
     {
-        for (auto it = resp.cbegin(); it != resp.cend();)
+        std::vector<Response> responses;
+        for (auto it = mResponses.begin(); it != mResponses.end();)
         {
             responses.insert(responses.end(), it->second.begin(), it->second.end());
             addTerminatedReqId(it->second, it->first);
-            resp.erase(it++);
+            it = mResponses.erase(it);
         }
+        return responses;
     };
 
+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeResponses();
+            responses = storeResponses();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeResponses();
+        responses = storeResponses();
     }
     return responses;
 }
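
Note on the hunk above: instead of filling a `responses` vector captured by reference, `storeResponses` now builds and returns its own vector, which is moved (or elided) out of the lambda, so the init-capture aliases disappear and all state access goes through `this`. A minimal standalone sketch of the same wait-then-drain shape; the `Response` and queue types here are stand-ins, not the real executor API:

    #include <chrono>
    #include <condition_variable>
    #include <map>
    #include <mutex>
    #include <optional>
    #include <vector>

    using IdType = unsigned long;        // stand-in for the executor's id type
    struct Response { IdType reqId; };   // stand-in for the executor's Response

    struct ResponseQueue
    {
        std::mutex mtx;
        std::condition_variable cv;
        std::map<IdType, std::vector<Response>> pending;
        bool shutdown{false};

        // Drain everything that is ready; returns an empty vector on timeout.
        std::vector<Response> awaitAll(std::optional<std::chrono::milliseconds> timeout)
        {
            std::unique_lock<std::mutex> lck(mtx);
            auto pred = [this] { return !pending.empty() || shutdown; };
            auto drain = [this]
            {
                std::vector<Response> out;
                for (auto it = pending.begin(); it != pending.end();)
                {
                    out.insert(out.end(), it->second.begin(), it->second.end());
                    it = pending.erase(it); // erase returns the next valid iterator
                }
                return out; // moved (or elided) out of the lambda
            };

            std::vector<Response> out;
            if (timeout)
            {
                if (cv.wait_for(lck, *timeout, pred))
                    out = drain();
            }
            else
            {
                cv.wait(lck, pred);
                out = drain();
            }
            return out;
        }
    };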
@@ -980,15 +982,16 @@ std::vector<Response> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<Response> responses;
     std::unique_lock<std::mutex> lck(mResponsesMtx);
-    auto pred = [&mShutdown = mShutdown, &resp = this->mResponses, reqId]() -> bool
-    { return (resp.find(reqId) != resp.end() && !resp.at(reqId).empty()) || mShutdown; };
-    auto storeIdResponse = [this, &resp = this->mResponses, &responses, reqId]()
+    auto pred = [this, reqId]() -> bool
+    { return (mResponses.find(reqId) != mResponses.end() && !mResponses.at(reqId).empty()) || mShutdown; };
+    auto storeIdResponse = [this, reqId]()
     {
-        responses.swap(resp.at(reqId));
-        resp.erase(reqId);
+        std::vector<Response> responses;
+        responses.swap(mResponses.at(reqId));
+        mResponses.erase(reqId);
         addTerminatedReqId(responses, reqId);
+        return responses;
     };
 
     // We don't process a terminated request again. Terminated request is defined as a response
@@ -1005,17 +1008,18 @@ std::vector<Response> Executor::Impl::awaitResponses(
         return {Response(reqId, err)};
     }
 
+    std::vector<Response> responses;
     if (timeout)
     {
         if (mResponsesCv.wait_for(lck, timeout.value(), pred))
         {
-            storeIdResponse();
+            responses = storeIdResponse();
         }
     }
     else
     {
         mResponsesCv.wait(lck, pred);
-        storeIdResponse();
+        responses = storeIdResponse();
     }
     return responses;
 }
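
The per-request overload uses swap-then-erase so the payload is moved out of the map entry rather than copied before the entry is dropped. The same move in isolation, with placeholder types:

    #include <map>
    #include <vector>

    // Steal one id's responses out of the map without copying elements.
    template <typename IdType, typename Response>
    std::vector<Response> drainOne(std::map<IdType, std::vector<Response>>& pending, IdType reqId)
    {
        std::vector<Response> out;
        out.swap(pending.at(reqId)); // O(1): swaps buffers, no element copies
        pending.erase(reqId);        // drop the now-empty entry
        return out;
    }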
@@ -1025,26 +1029,27 @@ std::vector<std::vector<Response>> Executor::Impl::awaitResponses(
 {
     TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
     checkParallelApiUsage(__func__);
-    std::vector<std::vector<Response>> v(requestIds.size());
+    std::vector<std::vector<Response>> responses;
+    responses.reserve(requestIds.size());
     if (timeout)
     {
         auto const start_time = std::chrono::high_resolution_clock::now();
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
         {
             auto const elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                 std::chrono::high_resolution_clock::now() - start_time);
-            v[i] = awaitResponses(requestIds[i],
-                timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0});
+            responses.emplace_back(awaitResponses(
+                requestId, timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0}));
         }
     }
     else
     {
-        for (unsigned i = 0; i < v.size(); ++i)
+        for (auto const requestId : requestIds)
        {
-            v[i] = awaitResponses(requestIds[i]);
+            responses.emplace_back(awaitResponses(requestId));
         }
     }
-    return v;
+    return responses;
 }
 
 SizeType32 Executor::Impl::getNumResponsesReady(std::optional<IdType> const& optId) const
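
The multi-id overload charges all of its sequential waits against a single shared deadline: each iteration subtracts the time already spent, clamping at zero so later requests still get a non-blocking poll rather than a fresh full timeout. The same budgeting as a standalone helper (the name is made up; the sketch uses steady_clock, the usual choice for measuring intervals, where the PR itself uses high_resolution_clock):

    #include <algorithm>
    #include <chrono>

    // Remaining share of an overall timeout budget, clamped at zero.
    std::chrono::milliseconds remainingBudget(
        std::chrono::milliseconds total, std::chrono::steady_clock::time_point start)
    {
        auto const elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - start);
        return std::max(total - elapsed, std::chrono::milliseconds{0});
    }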
@@ -1663,7 +1668,7 @@ void Executor::Impl::terminateActiveRequests(RequestList& activeRequests, std::s
         }
 
         // Remove from the requestList
-        activeRequests.erase(it++);
+        it = activeRequests.erase(it);
     }
 }
 
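Both `erase(it++)` call sites in this file are rewritten to `it = container.erase(it)`. For node-based containers such as std::list or std::map the two are equivalent, but the assignment form is the idiom that stays valid across all standard containers: for std::vector or std::deque, `erase(it++)` leaves the incremented iterator dangling. A minimal illustration:

    #include <list>

    // Remove even values while iterating; this shape works for any standard container.
    void removeEvens(std::list<int>& xs)
    {
        for (auto it = xs.begin(); it != xs.end();)
        {
            if (*it % 2 == 0)
                it = xs.erase(it); // erase returns the iterator past the removed element
            else
                ++it;
        }
    }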
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/executor/executorImpl.h
@@ -107,7 +107,7 @@ class Executor::Impl
     std::vector<Response> awaitResponses(std::optional<std::chrono::milliseconds> const& timeout = std::nullopt);
 
     std::vector<Response> awaitResponses(
-        IdType const& optId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
+        IdType const& reqId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
 
     std::vector<std::vector<Response>> awaitResponses(
         std::vector<IdType> const& requestIds, std::optional<std::chrono::milliseconds> const& timeout);
14 changes: 9 additions & 5 deletions tensorrt_llm/runtime/model_runner_cpp.py
@@ -933,13 +933,17 @@ def _initialize_and_fill_output(
         output_ids = [[[] for _ in range(num_sequences)]
                       for _ in range(len(request_ids))]
 
-        multi_responses = self.session.await_responses(request_ids)
-        responses = [
-            response for responses in multi_responses for response in responses
-        ]
+        all_responses = []
+        finished_request_ids = set()
+        while finished_request_ids != set(request_ids):
+            responses = self.session.await_responses()
+            for response in responses:
+                if response.result.is_final:
+                    finished_request_ids.add(response.request_id)
+            all_responses.extend(responses)
 
         return self._fill_output(
-            responses=responses,
+            responses=all_responses,
             output_ids=output_ids,
             end_id=end_id,
             return_dict=return_dict,
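
The Python runner now drains the global response queue until every tracked request has produced a final result, instead of issuing one blocking await per request id. A sketch of the equivalent loop against the C++ executor API; it is a hypothetical helper, not part of the PR, and it assumes the public `Response::getRequestId()`, `Response::hasError()`, and `Result::isFinal` accessors:

    #include <set>
    #include <utility>
    #include <vector>

    #include "tensorrt_llm/executor/executor.h"

    namespace texec = tensorrt_llm::executor;

    // Gather responses for a fixed set of request ids, treating a request as
    // done once a final (or error) response arrives for it.
    std::vector<texec::Response> collectAll(
        texec::Executor& executor, std::vector<texec::IdType> const& requestIds)
    {
        std::set<texec::IdType> const wanted(requestIds.begin(), requestIds.end());
        std::set<texec::IdType> finished;
        std::vector<texec::Response> all;
        while (finished != wanted)
        {
            for (auto&& response : executor.awaitResponses())
            {
                // An error response also terminates its request.
                if (response.hasError() || response.getResult().isFinal)
                    finished.insert(response.getRequestId());
                all.push_back(std::move(response));
            }
        }
        return all;
    }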