3737namespace tensorrt_llm ::batch_manager::kv_cache_manager
3838{
3939
40+ int getBlockNumAccountingForCP (int cpRank, int cpSize, int numTotalBlocks, bool strict)
41+ {
42+ TLLM_CHECK (cpRank >= 0 && cpRank < cpSize);
43+ if (cpSize == 1 )
44+ {
45+ return numTotalBlocks;
46+ }
47+ // NOTE: Non-strict mode may over-allocate blocks when numTotalBlocks is not divisible by cpSize.
48+ // This is a known limitation and will be addressed in a future MR.
49+ if (!strict)
50+ {
51+ // Simple ceiling division.
52+ return (numTotalBlocks + cpSize - 1 ) / cpSize;
53+ }
54+ // In strict mode, blocks are distributed among CP ranks in a round-robin fashion as evenly as possible.
55+ // When the number of blocks is not divisible by cpSize, the remainder shall be distributed evenly among
56+ // lowest-indexed CP ranks (let's call them overflow ranks).
57+ int numBlocksCurrRank = numTotalBlocks / cpSize;
58+ if (numTotalBlocks % cpSize > cpRank)
59+ {
60+ numBlocksCurrRank++;
61+ }
62+ return numBlocksCurrRank;
63+ }
64+
4065// some context rank in connection
4166std::vector<size_t > MLACacheFormatter::pickRecvConnections (
4267 size_t numConnections, CacheState const & selfConfig, SizeType32 selfIdx, CacheState const & destConfig) const
4368{
4469
4570 auto targetInfo = executor::kv_cache::targetIRanks (destConfig, selfConfig, selfIdx);
71+ // This function is called only by the gen side, and we only support CPSize=1 on the context side.
72+ TLLM_CHECK (targetInfo.mDomainCPSize );
4673 TLLM_CHECK (numConnections == targetInfo.mIRanks .size ());
4774 std::vector<size_t > ret;
4875 // targetInfo , mRanks [tpranks, dpranks]
@@ -97,14 +124,11 @@ void MLACacheFormatter::format(TransferSession& session)
97124 auto & bufferManager = session.getBufferManager ();
98125 TLLM_CHECK_WITH_INFO (llmRequest.mSamplingConfig .beamWidth == 1 , " Currently only supports beam width 1." );
99126 TLLM_CHECK (!connections.empty ());
100- // diff start
101127 if (!needSendCache (selfConfig, destConfig, selfIdx))
102128 {
103129 return ;
104130 }
105131
106- // diff end
107-
108132 auto const numPools = mCacheManager ->getBlockManager ().getNumPools ();
109133 auto blockRange = getBlockRangeForSending (mCacheManager , llmRequest);
110134
@@ -147,43 +171,48 @@ void MLACacheFormatter::format(TransferSession& session)
147171 return ;
148172 }
149173
150- auto cacheBlockSize = inputKvCacheBlocks.at (0 )->getSize ();
151-
152- auto cacheBufferId = mCacheTransBufferManager ->assignBufferIndexForSend ();
153- // diff start
154-
155174 auto targetInfo = executor::kv_cache::targetIRanks (destConfig, selfConfig, selfIdx);
156- auto ppRank = selfIdx
157- / (selfConfig.getParallelConfig ().mTensorParallelism * selfConfig.getParallelConfig ().mContextParallelism );
158- int selfAttentionLayerNum = selfConfig.getParallelConfig ().mAttentionLayerNumPerPP .at (ppRank);
159175 size_t pPDomainSize = targetInfo.mDomainPPSize ;
176+ size_t cPDomainSize = targetInfo.mDomainCPSize ;
177+
160178 auto getBufferSizeForTarget = [&]()
161179 {
162- std::vector<size_t > bufferSizeForTarget (pPDomainSize, 0 );
163- size_t cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
164- for (size_t i = 0 ; i < pPDomainSize; i++)
180+ auto const ppRank = selfIdx
181+ / (selfConfig.getParallelConfig ().mTensorParallelism * selfConfig.getParallelConfig ().mContextParallelism );
182+ auto const selfAttentionLayerNum = selfConfig.getParallelConfig ().mAttentionLayerNumPerPP .at (ppRank);
183+ auto const cacheBlockSize = inputKvCacheBlocks.at (0 )->getSize ();
184+ auto const blockSizePerLayer = cacheBlockSize / selfAttentionLayerNum;
185+ std::vector<size_t > bufferSizeForTarget (pPDomainSize * cPDomainSize, 0 );
186+ for (size_t ppDomainIdx = 0 ; ppDomainIdx < pPDomainSize; ppDomainIdx++)
165187 {
166- auto layerNum = targetInfo.getPeerPPDomainLayerNum (i);
167- bufferSizeForTarget[i] = cacheSizePerLayer * layerNum;
188+ auto const peerAttentionLayerNum = targetInfo.getPeerPPDomainLayerNum (ppDomainIdx);
189+ for (size_t cpDomainIdx = 0 ; cpDomainIdx < cPDomainSize; cpDomainIdx++)
190+ {
191+ auto const idx = cpDomainIdx * pPDomainSize + ppDomainIdx;
192+ // Note: contextCP is always 1. So, cpDomainSize == genCPSize and cpDomainIdx == genCPRank.
193+ auto const peerBlockNum
194+ = getBlockNumAccountingForCP (cpDomainIdx, cPDomainSize, blockNum, /* strict=*/ false );
195+ bufferSizeForTarget[idx] = blockSizePerLayer * peerAttentionLayerNum * peerBlockNum;
196+ }
168197 }
169198 return bufferSizeForTarget;
170199 };
171200 auto bufferEleSizes = getBufferSizeForTarget ();
201+ auto cacheBufferId = mCacheTransBufferManager ->assignBufferIndexForSend ();
172202 auto result = mCacheTransBufferManager ->getOrAllocateSendBuffers (
173- cacheBufferId, static_cast <int >(pPDomainSize), bufferEleSizes, bufferManager);
203+ cacheBufferId, static_cast <int >(pPDomainSize * cPDomainSize ), bufferEleSizes, bufferManager);
174204 auto & outputSplitCaches = std::get<0 >(result);
175205 auto & bufferCoverTargetNum = std::get<1 >(result);
176206 auto & onlyUseDynamicBuffer = std::get<2 >(result);
177207 auto * agentConnnecion = dynamic_cast <executor::kv_cache::AgentConnection const *>(connections[0 ]);
178208 if (agentConnnecion != nullptr )
179209 {
180- TLLM_CHECK_WITH_INFO (bufferCoverTargetNum == pPDomainSize, " Agent need all buffer pre-allocated" );
210+ TLLM_CHECK_WITH_INFO (
211+ bufferCoverTargetNum == pPDomainSize * cPDomainSize, " Agent need all buffer pre-allocated" );
181212 TLLM_CHECK (onlyUseDynamicBuffer == false );
182213 }
183- // diff end
184-
185- // The size of outputSplitCaches should be equal to pPDomainSize
186214
215+ // The size of outputSplitCaches should be equal to pPDomainSize * cPDomainSize.
187216 SizeType32 window = mCacheManager ->getBlockManager ().getPoolWindowSize (0 );
188217 std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocksPerWindow;
189218 inputKvCacheBlocksPerWindow.emplace (window, inputKvCacheBlocks);
@@ -203,7 +232,7 @@ void MLACacheFormatter::format(TransferSession& session)
203232
204233 TLLM_CUDA_CHECK (cudaSetDevice (deviceId));
205234 auto startTime = std::chrono::steady_clock::now ();
206- auto cacheIdx = processIdx % pPDomainSize;
235+ auto cacheIdx = processIdx % ( pPDomainSize * cPDomainSize) ;
207236 if (cacheIdx < bufferCoverTargetNum)
208237 {
209238 size_t size = outputSplitCaches.at (cacheIdx)->getSizeInBytes ();
@@ -259,7 +288,8 @@ void MLACacheFormatter::format(TransferSession& session)
259288 else
260289 {
261290 // concurrency num
262- auto concurrencyNum = std::min (std::max (static_cast <size_t >(1 ), bufferCoverTargetNum), pPDomainSize);
291+ auto concurrencyNum
292+ = std::min (std::max (static_cast <size_t >(1 ), bufferCoverTargetNum), pPDomainSize * cPDomainSize);
263293
264294 auto remainSendNum = connections.size ();
265295
@@ -307,9 +337,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
307337 auto & bufferManager = session.getBufferManager ();
308338 auto arrivalTime = llmRequest.getPerfMetrics ().timingMetrics .arrivalTime ;
309339 bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point ();
310- // diff start
311340 auto pickUpConnections = pickRecvConnections (connections.size (), selfConfig, selfIdx, destConfig);
312- // diff end
313341 auto blockRange = getBlockRangeForReceiving (mCacheManager , llmRequest);
314342 std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
315343 std::vector<runtime::ITensor::SharedPtr> outputBuffers;
@@ -364,23 +392,24 @@ void MLACacheFormatter::unformat(TransferSession& session)
364392 cacheBufferId = mCacheTransBufferManager ->assignBufferIndexForRecv ();
365393 }
366394
367- auto cacheBlockSize = outputBuffers.at (0 )->getSize ();
368-
369395 auto targetNum = pickUpConnections.size ();
370- auto targetInfo = executor::kv_cache::targetIRanks (destConfig, selfConfig, selfIdx);
371- auto ppRank = selfIdx
372- / (selfConfig.getParallelConfig ().mTensorParallelism * selfConfig.getParallelConfig ().mContextParallelism );
373- auto selfAttentionLayerNum = selfConfig.getParallelConfig ().mAttentionLayerNumPerPP .at (ppRank);
374- TLLM_CHECK_WITH_INFO (selfAttentionLayerNum != 0 , " selfAttentionLayerNum should not be 0" );
375396
376397 auto getBufferSizeForTarget = [&]()
377398 {
399+ auto const targetInfo = executor::kv_cache::targetIRanks (destConfig, selfConfig, selfIdx);
400+ auto const cacheBlockSize = outputBuffers.at (0 )->getSize ();
401+ auto const ppRank = selfIdx
402+ / (selfConfig.getParallelConfig ().mTensorParallelism
403+ * selfConfig.getParallelConfig ().mContextParallelism );
404+ auto const selfAttentionLayerNum = selfConfig.getParallelConfig ().mAttentionLayerNumPerPP .at (ppRank);
405+ TLLM_CHECK_WITH_INFO (selfAttentionLayerNum != 0 , " selfAttentionLayerNum should not be 0" );
378406 std::vector<size_t > bufferEleSizes (targetNum, 0 );
379- auto cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
407+ auto const cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum;
380408 for (size_t i = 0 ; i < targetNum; i++)
381409 {
382- auto layerNum = targetInfo.getPeerPPDomainLayerNum (static_cast <SizeType32>(pickUpConnections[i]));
383- bufferEleSizes[i] = cacheSizePerLayer * layerNum;
410+ auto const peerAttentionLayerNum
411+ = targetInfo.getPeerPPDomainLayerNum (static_cast <SizeType32>(pickUpConnections[i]));
412+ bufferEleSizes[i] = cacheSizePerLayer * peerAttentionLayerNum;
384413 }
385414 return bufferEleSizes;
386415 };
@@ -506,9 +535,10 @@ void MLACacheFormatter::unformat(TransferSession& session)
506535 outputCachesPerWindow.emplace (window, outputBuffers);
507536 NVTX3_SCOPED_RANGE (formatInputConcatenate);
508537
509- // recvSplitCaches size == ppdomainsize
538+ // recvSplitCaches size == ppdomainsize * cpdomainsize.
510539 executor::kv_cache::concatKvCacheV2Dispatch (
511540 recvSplitCaches, outputCachesPerWindow, destConfig, selfConfig, selfIdx, bufferManager);
541+ bufferManager.getStream ().synchronize ();
512542 }
513543 bufferManager.getStream ().synchronize ();
514544 }
@@ -581,14 +611,6 @@ void MLACacheFormatter::unformat(TransferSession& session)
581611 TLLM_LOG_WARNING (" MLACacheFormatter::inquireSupport: TP size must be divisible by DP size" );
582612 return false ;
583613 }
584- if (selfConfig.getParallelConfig ().mContextParallelism != 1
585- || destConfig.getParallelConfig ().mContextParallelism != 1 )
586- {
587- TLLM_LOG_WARNING (
588- " MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d)." ,
589- selfConfig.getParallelConfig ().mContextParallelism , destConfig.getParallelConfig ().mContextParallelism );
590- return false ;
591- }
592614 if (destConfig.getParallelConfig ().mEnableAttentionDP
593615 && (destConfig.getParallelConfig ().mTensorParallelism % destConfig.getParallelConfig ().mDPsize != 0 ))
594616 {
0 commit comments