Skip to content

Commit b69a4ef

Browse files
committed
TRTLLM-7731 KV cache transmission in disagg with CP on gen side
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 77657a1 commit b69a4ef

File tree

6 files changed

+244
-132
lines changed

6 files changed

+244
-132
lines changed

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
179179
RequestInfo requestInfo(requestId, mSelfState);
180180

181181
auto disableSelectiveCacheTransfer = common::getEnvDisableSelectiveCacheTransfer()
182-
|| (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1);
182+
|| (mFormatter->getCacheManager()->getBlockManager().getNumPools() > 1)
183+
|| (mSelfState.getCacheState().value().getParallelConfig().mDPsize > 1);
183184
if (!disableSelectiveCacheTransfer)
184185
{
185186
auto* cacheManager = mFormatter->getCacheManager();

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,39 @@
3737
namespace tensorrt_llm::batch_manager::kv_cache_manager
3838
{
3939

40+
int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict)
41+
{
42+
TLLM_CHECK(cpRank >= 0 && cpRank < cpSize);
43+
if (cpSize == 1)
44+
{
45+
return numTotalBlocks;
46+
}
47+
// NOTE: Non-strict mode may over-allocate blocks when numTotalBlocks is not divisible by cpSize.
48+
// This is a known limitation and will be addressed in a future MR.
49+
if (!strict)
50+
{
51+
// Simple ceiling division.
52+
return (numTotalBlocks + cpSize - 1) / cpSize;
53+
}
54+
// In strict mode, blocks are distributed among CP ranks in a round-robin fashion as evenly as possible.
55+
// When the number of blocks is not divisible by cpSize, the remainder shall be distributed evenly among
56+
// lowest-indexed CP ranks (let's call them overflow ranks).
57+
int numBlocksCurrRank = numTotalBlocks / cpSize;
58+
if (numTotalBlocks % cpSize > cpRank)
59+
{
60+
numBlocksCurrRank++;
61+
}
62+
return numBlocksCurrRank;
63+
}
64+
4065
// some context rank in connection
4166
std::vector<size_t> MLACacheFormatter::pickRecvConnections(
4267
size_t numConnections, CacheState const& selfConfig, SizeType32 selfIdx, CacheState const& destConfig) const
4368
{
4469

4570
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
71+
// This function is called only by the gen side, and we only support CPSize=1 on the context side.
72+
TLLM_CHECK(targetInfo.mDomainCPSize);
4673
TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
4774
std::vector<size_t> ret;
4875
// targetInfo , mRanks [tpranks, dpranks]
@@ -97,14 +124,11 @@ void MLACacheFormatter::format(TransferSession& session)
97124
auto& bufferManager = session.getBufferManager();
98125
TLLM_CHECK_WITH_INFO(llmRequest.mSamplingConfig.beamWidth == 1, "Currently only supports beam width 1.");
99126
TLLM_CHECK(!connections.empty());
100-
// diff start
101127
if (!needSendCache(selfConfig, destConfig, selfIdx))
102128
{
103129
return;
104130
}
105131

106-
// diff end
107-
108132
auto const numPools = mCacheManager->getBlockManager().getNumPools();
109133
auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest);
110134

@@ -147,43 +171,48 @@ void MLACacheFormatter::format(TransferSession& session)
147171
return;
148172
}
149173

150-
auto cacheBlockSize = inputKvCacheBlocks.at(0)->getSize();
151-
152-
auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
153-
// diff start
154-
155174
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
156-
auto ppRank = selfIdx
157-
/ (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism);
158-
int selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
159175
size_t pPDomainSize = targetInfo.mDomainPPSize;
176+
size_t cPDomainSize = targetInfo.mDomainCPSize;
177+
160178
auto getBufferSizeForTarget = [&]()
161179
{
162-
std::vector<size_t> bufferSizeForTarget(pPDomainSize, 0);
163-
size_t cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
164-
for (size_t i = 0; i < pPDomainSize; i++)
180+
auto const ppRank = selfIdx
181+
/ (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism);
182+
auto const selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
183+
auto const cacheBlockSize = inputKvCacheBlocks.at(0)->getSize();
184+
auto const blockSizePerLayer = cacheBlockSize / selfAttentionLayerNum;
185+
std::vector<size_t> bufferSizeForTarget(pPDomainSize * cPDomainSize, 0);
186+
for (size_t ppDomainIdx = 0; ppDomainIdx < pPDomainSize; ppDomainIdx++)
165187
{
166-
auto layerNum = targetInfo.getPeerPPDomainLayerNum(i);
167-
bufferSizeForTarget[i] = cacheSizePerLayer * layerNum;
188+
auto const peerAttentionLayerNum = targetInfo.getPeerPPDomainLayerNum(ppDomainIdx);
189+
for (size_t cpDomainIdx = 0; cpDomainIdx < cPDomainSize; cpDomainIdx++)
190+
{
191+
auto const idx = cpDomainIdx * pPDomainSize + ppDomainIdx;
192+
// Note: contextCP is always 1. So, cpDomainSize == genCPSize and cpDomainIdx == genCPRank.
193+
auto const peerBlockNum
194+
= getBlockNumAccountingForCP(cpDomainIdx, cPDomainSize, blockNum, /*strict=*/false);
195+
bufferSizeForTarget[idx] = blockSizePerLayer * peerAttentionLayerNum * peerBlockNum;
196+
}
168197
}
169198
return bufferSizeForTarget;
170199
};
171200
auto bufferEleSizes = getBufferSizeForTarget();
201+
auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
172202
auto result = mCacheTransBufferManager->getOrAllocateSendBuffers(
173-
cacheBufferId, static_cast<int>(pPDomainSize), bufferEleSizes, bufferManager);
203+
cacheBufferId, static_cast<int>(pPDomainSize * cPDomainSize), bufferEleSizes, bufferManager);
174204
auto& outputSplitCaches = std::get<0>(result);
175205
auto& bufferCoverTargetNum = std::get<1>(result);
176206
auto& onlyUseDynamicBuffer = std::get<2>(result);
177207
auto* agentConnnecion = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[0]);
178208
if (agentConnnecion != nullptr)
179209
{
180-
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == pPDomainSize, "Agent need all buffer pre-allocated");
210+
TLLM_CHECK_WITH_INFO(
211+
bufferCoverTargetNum == pPDomainSize * cPDomainSize, "Agent need all buffer pre-allocated");
181212
TLLM_CHECK(onlyUseDynamicBuffer == false);
182213
}
183-
// diff end
184-
185-
// The size of outputSplitCaches should be equal to pPDomainSize
186214

215+
// The size of outputSplitCaches should be equal to pPDomainSize * cPDomainSize.
187216
SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(0);
188217
std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocksPerWindow;
189218
inputKvCacheBlocksPerWindow.emplace(window, inputKvCacheBlocks);
@@ -203,7 +232,7 @@ void MLACacheFormatter::format(TransferSession& session)
203232

204233
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
205234
auto startTime = std::chrono::steady_clock::now();
206-
auto cacheIdx = processIdx % pPDomainSize;
235+
auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
207236
if (cacheIdx < bufferCoverTargetNum)
208237
{
209238
size_t size = outputSplitCaches.at(cacheIdx)->getSizeInBytes();
@@ -259,7 +288,8 @@ void MLACacheFormatter::format(TransferSession& session)
259288
else
260289
{
261290
// concurrency num
262-
auto concurrencyNum = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize);
291+
auto concurrencyNum
292+
= std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), pPDomainSize * cPDomainSize);
263293

264294
auto remainSendNum = connections.size();
265295

@@ -307,9 +337,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
307337
auto& bufferManager = session.getBufferManager();
308338
auto arrivalTime = llmRequest.getPerfMetrics().timingMetrics.arrivalTime;
309339
bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point();
310-
// diff start
311340
auto pickUpConnections = pickRecvConnections(connections.size(), selfConfig, selfIdx, destConfig);
312-
// diff end
313341
auto blockRange = getBlockRangeForReceiving(mCacheManager, llmRequest);
314342
std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
315343
std::vector<runtime::ITensor::SharedPtr> outputBuffers;
@@ -364,23 +392,24 @@ void MLACacheFormatter::unformat(TransferSession& session)
364392
cacheBufferId = mCacheTransBufferManager->assignBufferIndexForRecv();
365393
}
366394

367-
auto cacheBlockSize = outputBuffers.at(0)->getSize();
368-
369395
auto targetNum = pickUpConnections.size();
370-
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
371-
auto ppRank = selfIdx
372-
/ (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism);
373-
auto selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
374-
TLLM_CHECK_WITH_INFO(selfAttentionLayerNum != 0, "selfAttentionLayerNum should not be 0");
375396

376397
auto getBufferSizeForTarget = [&]()
377398
{
399+
auto const targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
400+
auto const cacheBlockSize = outputBuffers.at(0)->getSize();
401+
auto const ppRank = selfIdx
402+
/ (selfConfig.getParallelConfig().mTensorParallelism
403+
* selfConfig.getParallelConfig().mContextParallelism);
404+
auto const selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank);
405+
TLLM_CHECK_WITH_INFO(selfAttentionLayerNum != 0, "selfAttentionLayerNum should not be 0");
378406
std::vector<size_t> bufferEleSizes(targetNum, 0);
379-
auto cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
407+
auto const cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
380408
for (size_t i = 0; i < targetNum; i++)
381409
{
382-
auto layerNum = targetInfo.getPeerPPDomainLayerNum(static_cast<SizeType32>(pickUpConnections[i]));
383-
bufferEleSizes[i] = cacheSizePerLayer * layerNum;
410+
auto const peerAttentionLayerNum
411+
= targetInfo.getPeerPPDomainLayerNum(static_cast<SizeType32>(pickUpConnections[i]));
412+
bufferEleSizes[i] = cacheSizePerLayer * peerAttentionLayerNum;
384413
}
385414
return bufferEleSizes;
386415
};
@@ -506,9 +535,10 @@ void MLACacheFormatter::unformat(TransferSession& session)
506535
outputCachesPerWindow.emplace(window, outputBuffers);
507536
NVTX3_SCOPED_RANGE(formatInputConcatenate);
508537

509-
// recvSplitCaches size == ppdomainsize
538+
// recvSplitCaches size == ppdomainsize * cpdomainsize.
510539
executor::kv_cache::concatKvCacheV2Dispatch(
511540
recvSplitCaches, outputCachesPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
541+
bufferManager.getStream().synchronize();
512542
}
513543
bufferManager.getStream().synchronize();
514544
}
@@ -581,14 +611,6 @@ void MLACacheFormatter::unformat(TransferSession& session)
581611
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
582612
return false;
583613
}
584-
if (selfConfig.getParallelConfig().mContextParallelism != 1
585-
|| destConfig.getParallelConfig().mContextParallelism != 1)
586-
{
587-
TLLM_LOG_WARNING(
588-
"MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
589-
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
590-
return false;
591-
}
592614
if (destConfig.getParallelConfig().mEnableAttentionDP
593615
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
594616
{

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@
2222
namespace tensorrt_llm::batch_manager::kv_cache_manager
2323
{
2424

25+
/**
26+
* @brief Calculate the number of blocks allocated to a specific Context Parallelism (CP) rank.
27+
*
28+
* This function determines how many blocks should be allocated to a given CP rank when
29+
* distributing a total number of blocks across multiple CP ranks. It supports two distribution
30+
* modes: strict and non-strict.
31+
*
32+
* @param cpRank The rank (index) of the current CP process. Must be in range [0, cpSize).
33+
* @param cpSize The total number of CP ranks/processes in the parallel group.
34+
* @param numTotalBlocks The total number of blocks to be distributed across all CP ranks.
35+
* @param strict Flag controlling the distribution strategy:
36+
* - true: Use strict round-robin distribution with exact allocation
37+
* - false: Use ceiling division which may over-allocate
38+
*
39+
* @return The number of blocks allocated to the specified CP rank.
40+
*/
41+
int getBlockNumAccountingForCP(int cpRank, int cpSize, int numTotalBlocks, bool strict);
42+
2543
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
2644
// parallel topology is completely identical, making it the preferred method.
2745
class MLACacheFormatter final : public BaseCacheFormatter

0 commit comments

Comments
 (0)