Skip to content

Commit ce6aba3

Browse files
committed
fixup! [nvbugs/5274894] fix: Sort requests for functional correctness and performance (adapted from NVIDIA#4608)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 75f6035 commit ce6aba3

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
309309
}
310310
}
311311

312-
utils::sortRequests(contextRequests, generationRequests);
312+
utils::sortRequests(contextRequests, generationRequests, !allContextRequestsFit);
313313

314314
TLLM_LOG_DEBUG(
315315
"batchSize (num ctx/enc requests + num gen requests): %u", contextRequests.size() + generationRequests.size());

cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,31 @@ TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector
3939
return requestIds;
4040
}
4141

42-
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests)
42+
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent)
4343
{
4444
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
4545

46-
// Move context requests that reached the last context chunk to the end of the vector.
47-
// This order is required for moveFinishedContextRequestsToGeneration.
48-
auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(),
49-
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
50-
5146
auto sortByLoraId = [](RequestVector::iterator begin, RequestVector::iterator end)
5247
{
5348
std::sort(
5449
begin, end, [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
5550
};
5651

57-
// Sort context requests by lora task id, but keep finished requests separate.
58-
sortByLoraId(contextRequests.begin(), firstFinished);
59-
sortByLoraId(firstFinished, contextRequests.end());
60-
// Sort generation requests by lora task id.
52+
if (chunksPresent)
53+
{
54+
// Move context requests that reached the last context chunk to the end of the vector.
55+
// This order is required for moveFinishedContextRequestsToGeneration.
56+
auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(),
57+
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
58+
59+
// Sort context requests by lora task id, but keep finished requests separate.
60+
sortByLoraId(contextRequests.begin(), firstFinished);
61+
sortByLoraId(firstFinished, contextRequests.end());
62+
}
63+
else
64+
{
65+
sortByLoraId(contextRequests.begin(), contextRequests.end());
66+
}
6167
sortByLoraId(generationRequests.begin(), generationRequests.end());
6268

6369
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector
4040
//! Sort requests by lora task id for performance.
4141
//! @param contextRequests The context requests.
4242
//! @param generationRequests The generation requests.
43-
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests);
43+
//! @param chunksPresent Whether context chunks are present.
44+
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent);
4445

4546
//! @brief Move finished context requests to generation requests.
4647
//! @details This function assumes that the context requests are sorted so that requests with isLastContextChunk() are

0 commit comments

Comments
 (0)