Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -953,10 +953,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
= (*mMicroBatchScheduler)(fittingRequests, mInflightReqIds, mMaxBatchSizeRuntime, mMaxNumTokensRuntime);
TLLM_CHECK(currRequests.size() <= static_cast<size_t>(getMaxBatchSize()));

// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
std::partition(currRequests.contextRequests.begin(), currRequests.contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
utils::sortRequests(currRequests);

(*mPauseRequests)(requestsToPause, mInflightReqIds, mReqIdsToPause, false, *mSeqSlotManager, mKvCacheManager,
mCrossKvCacheManager, mPeftCacheManager);
Expand Down Expand Up @@ -984,8 +981,6 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
}
}

utils::sortByLoraId(currRequests);

(*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);

if (mKvCacheManager)
Expand Down
30 changes: 20 additions & 10 deletions cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,26 @@ TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector
return requestIds;
}

void sortByLoraId(ScheduledRequests& scheduledRequests)
void sortRequests(ScheduledRequests& scheduledRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto sortRequests = [](RequestVector& requests)
// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
auto firstFinished = std::partition(scheduledRequests.contextRequests.begin(),
scheduledRequests.contextRequests.end(), [](auto const& llmReq) { return !llmReq->isLastContextChunk(); });

auto sortByLoraId = [](RequestVector::iterator begin, RequestVector::iterator end)
{
std::sort(requests.begin(), requests.end(),
[](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
std::sort(
begin, end, [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
};
sortRequests(scheduledRequests.contextRequests);
sortRequests(scheduledRequests.generationRequests);

// Sort context requests by lora task id, but keep finished requests separate.
sortByLoraId(scheduledRequests.contextRequests.begin(), firstFinished);
sortByLoraId(firstFinished, scheduledRequests.contextRequests.end());
// Sort generation requests by lora task id.
sortByLoraId(scheduledRequests.generationRequests.begin(), scheduledRequests.generationRequests.end());

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
Expand All @@ -60,11 +69,12 @@ void moveFinishedContextRequestsToGeneration(ScheduledRequests& scheduledRequest

auto& contextRequests = scheduledRequests.contextRequests;
auto& generationRequests = scheduledRequests.generationRequests;

auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isContextFinished(); });
auto firstFinished = std::find_if(
contextRequests.begin(), contextRequests.end(), [](auto const& llmReq) { return llmReq->isContextFinished(); });
TLLM_LOG_DEBUG(
"Moving %ld finished context requests to generation.", std::distance(firstFinished, contextRequests.end()));
"Found %ld unfinished chunked context requests. Found %ld finished context requests, moving them to "
"generation.",
std::distance(contextRequests.begin(), firstFinished), std::distance(firstFinished, contextRequests.end()));
generationRequests.insert(generationRequests.begin(), std::make_move_iterator(firstFinished),
std::make_move_iterator(contextRequests.end()));
contextRequests.erase(firstFinished, contextRequests.end());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ using OptionalRef = common::OptionalRef<T>;

TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector const& generationRequests);

void sortByLoraId(ScheduledRequests& scheduledRequests);
//! @brief Sort requests for functional correctness and performance.
//! @details Sort context requests for moveFinishedContextRequestsToGeneration.
//! Sort requests by lora task id for performance.
//! @param scheduledRequests The scheduled context and generation requests.
void sortRequests(ScheduledRequests& scheduledRequests);

//! @brief Move finished context requests to generation requests.
//! @details This function assumes that the context requests are sorted so that requests with isLastContextChunk() are
Expand Down