Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
*/

#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h"

namespace tle = tensorrt_llm::executor;

namespace tensorrt_llm::batch_manager
{

Expand Down Expand Up @@ -310,13 +309,7 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
}
}

if (!allContextRequestsFit)
{
// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
std::partition(contextRequests.begin(), contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
}
utils::sortRequests(contextRequests, generationRequests, !allContextRequestsFit);

TLLM_LOG_DEBUG(
"batchSize (num ctx/enc requests + num gen requests): %u", contextRequests.size() + generationRequests.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -977,8 +977,6 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
}
}

utils::sortByLoraId(currRequests);

(*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);

if (mKvCacheManager)
Expand Down
31 changes: 24 additions & 7 deletions cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,32 @@ TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector
return requestIds;
}

void sortByLoraId(ScheduledRequests& scheduledRequests)
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto sortRequests = [](RequestVector& requests)
auto sortByLoraId = [](RequestVector::iterator begin, RequestVector::iterator end)
{
std::sort(requests.begin(), requests.end(),
[](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
std::sort(
begin, end, [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
};
sortRequests(scheduledRequests.contextRequests);
sortRequests(scheduledRequests.generationRequests);

if (chunksPresent)
{
// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });

// Sort context requests by lora task id, but keep finished requests separate.
sortByLoraId(contextRequests.begin(), firstFinished);
sortByLoraId(firstFinished, contextRequests.end());
}
else
{
sortByLoraId(contextRequests.begin(), contextRequests.end());
}
sortByLoraId(generationRequests.begin(), generationRequests.end());

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
Expand All @@ -63,7 +78,9 @@ void moveFinishedContextRequestsToGeneration(ScheduledRequests& scheduledRequest
auto firstFinished = std::find_if(
contextRequests.begin(), contextRequests.end(), [](auto const& llmReq) { return llmReq->isContextFinished(); });
TLLM_LOG_DEBUG(
"Moving %ld finished context requests to generation.", std::distance(firstFinished, contextRequests.end()));
"Found %ld unfinished chunked context requests. Found %ld finished context requests, moving them to "
"generation.",
std::distance(contextRequests.begin(), firstFinished), std::distance(firstFinished, contextRequests.end()));
generationRequests.insert(generationRequests.begin(), std::make_move_iterator(firstFinished),
std::make_move_iterator(contextRequests.end()));
contextRequests.erase(firstFinished, contextRequests.end());
Expand Down
8 changes: 7 additions & 1 deletion cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ using OptionalRef = common::OptionalRef<T>;

TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector const& generationRequests);

void sortByLoraId(ScheduledRequests& scheduledRequests);
//! @brief Sort scheduled requests in place for functional correctness and performance.
//! @details When context chunks are present, context requests that reached their last context
//! chunk are partitioned to the end of the vector — the ordering required by
//! moveFinishedContextRequestsToGeneration. Within each partition, and across all generation
//! requests, requests are additionally sorted by LoRA task id for performance.
//! @param contextRequests The context requests, reordered in place.
//! @param generationRequests The generation requests, reordered in place.
//! @param chunksPresent Whether context chunks are present among the context requests.
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent);

//! @brief Move finished context requests to generation requests.
//! @details This function assumes that the context requests are sorted so that requests with isLastContextChunk() are
Expand Down