Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,15 @@ void MoeLoadBalancer::finalizeModel()
{
layer->finalizeModel();
}
generateUpdatePlan();
startThreads();
if (mLayerUpdatesPerIter > 0)
{
generateUpdatePlan();
startThreads();
}
else
{
mWorkerThreadStopped = true;
}
mModelFinalized = true;
}

Expand All @@ -751,10 +758,12 @@ void MoeLoadBalancer::startIter(int64_t iterId, bool enableStatistic, bool enabl
TLLM_CHECK_WITH_INFO(mIterId + 1 == iterId, "Expected iterId=%ld, but got %ld", mIterId + 1, iterId);

mIterId = iterId;
mStatisticEnabled = enableStatistic;
// disable update for warm up iters.
bool isWarmUpIter = mIterId <= mWarmUpUntilIter;
mUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;
bool fixedUpdateWeightsEnabled = enableUpdateWeights && !isWarmUpIter;

IterInfo iterInfo{iterId, enableStatistic, fixedUpdateWeightsEnabled};
mIterInfoQueue.push(iterInfo);
mWorkerThreadCondition.notify_one();
}

Expand All @@ -780,21 +789,24 @@ void MoeLoadBalancer::shutdown()
void MoeLoadBalancer::workerThread()
{
TLLM_CUDA_CHECK(cudaSetDevice(mCudaDeviceId));
int64_t iterId = -1;
while (true)
{
int64_t iterId;
bool iterUpdateWeightsEnabled, iterStatisticEnabled;
{
std::unique_lock<std::mutex> lock(mWorkerThreadMutex);
mWorkerThreadCondition.wait(lock, [this] { return mWaitIterId == mIterId || mWorkerThreadStopped; });
iterId = mIterId;
if (mWorkerThreadStopped)
mWorkerThreadCondition.wait(lock, [this] { return !mIterInfoQueue.empty() || mWorkerThreadStopped; });
if (mIterInfoQueue.empty() && mWorkerThreadStopped)
{
break;
}
mWaitIterId = mIterId + 1;
iterUpdateWeightsEnabled = mUpdateWeightsEnabled;
iterStatisticEnabled = mStatisticEnabled;
auto iterInfo = mIterInfoQueue.front();
mIterInfoQueue.pop();
TLLM_CHECK_WITH_INFO(iterInfo.iterId == iterId + 1, "Jump detected, iterId=%ld, but got next iterId=%ld",
iterId, iterInfo.iterId);
iterId = iterInfo.iterId;
iterUpdateWeightsEnabled = iterInfo.updateWeightsEnabled;
iterStatisticEnabled = iterInfo.statisticEnabled;
}
for (int layerId = 0; static_cast<size_t>(layerId) < mLayers.size(); ++layerId)
{
Expand Down
13 changes: 10 additions & 3 deletions cpp/tensorrt_llm/runtime/moeLoadBalancer.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ class MoeLoadBalancer
std::mutex mWorkerThreadMutex;
std::condition_variable mWorkerThreadCondition;
bool mWorkerThreadStopped = false;
int64_t mWaitIterId = 0;
int64_t mWarmUpUntilIter = -1;

// we use a separate thread to compute and update weights to avoid possible blocking for next layer due to slow
Expand All @@ -252,8 +251,16 @@ class MoeLoadBalancer
std::vector<std::shared_ptr<SingleLayerMoeLoadBalancer>> mLayers;

int64_t mIterId = -1;
bool mStatisticEnabled = true;
bool mUpdateWeightsEnabled = true;

// Per-iteration settings bundled into one value so they can be queued and
// consumed together instead of being read from separate member flags.
// NOTE: the fields are filled positionally via brace-initialization
// (iterId, statisticEnabled, updateWeightsEnabled), so their declaration
// order must not be changed.
struct IterInfo
{
// Identifier of the iteration these settings belong to; -1 means "not set yet".
int64_t iterId{-1};
// Whether statistics are enabled for this iteration.
bool statisticEnabled{true};
// Whether weight updates are enabled for this iteration.
bool updateWeightsEnabled{true};
};

std::queue<IterInfo> mIterInfoQueue;

bool mModelFinalized = false;

int mEpRank = 0;
Expand Down