@@ -166,6 +166,9 @@ void CacheFormatter::format(TransferSession& session)
166166 auto const numPools = blockManager.getNumPools ();
167167 // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
168168
169+ auto lastTokenTime = llmRequest.getPerfMetrics ().timingMetrics .lastTokenTime ;
170+ bool recordDelay = lastTokenTime != std::chrono::steady_clock::time_point ();
171+
169172 bool layerWise = common::getEnvDisaggLayerwise () && numPools == 1 ;
170173 if (layerWise)
171174 {
@@ -350,9 +353,14 @@ void CacheFormatter::format(TransferSession& session)
350353 }
351354
352355 auto endTime = std::chrono::steady_clock::now ();
356+ double delay = 0.0 ;
357+ if (recordDelay)
358+ {
359+ delay = std::chrono::duration<double , std::milli>(startTime - lastTokenTime).count ();
360+ }
353361 double cacheTransferTime
354362 = std::max (0.0 , std::chrono::duration<double , std::milli>(endTime - startTime).count ());
355- kvCacheMeasureHelper.appendKVCacheTransfer (llmRequest.mRequestId , cacheTransferTime, size);
363+ kvCacheMeasureHelper.appendKVCacheTransfer (llmRequest.mRequestId , delay, cacheTransferTime, size);
356364 };
357365
358366 if (connections.size () > 1 )
@@ -408,16 +416,19 @@ void CacheFormatter::unformat(TransferSession& session)
408416{
409417 NVTX3_SCOPED_RANGE (CacheFormatter_unformat);
410418 auto const & llmRequest = session.getLlmRequest ();
419+ auto const ctxReqId = llmRequest.getContextPhaseParams ().value ().getReqId ();
411420 TLLM_LOG_DEBUG (mpi::MpiComm::world ().getRank (),
412- " Start receiving KV cache for request ID: %ld, context request ID: %ld." , llmRequest.mRequestId ,
413- llmRequest.getContextPhaseParams ().value ().getReqId ());
421+ " Start receiving KV cache for request ID: %ld, context request ID: %ld." , llmRequest.mRequestId , ctxReqId);
414422 auto const & connections = session.getConnections ();
415423 auto const & selfConfig = session.getSelfState ().getCacheState ().value ();
416424 auto const & destConfig = session.getOtherState ().getCacheState ().value ();
417425 auto const selfIdx = session.getSelfState ().getCommState ().value ().getSelfIdx ();
418426 auto & bufferManager = session.getBufferManager ();
419427 auto blockRange = getBlockRangeForReceiving (mCacheManager , llmRequest);
420428
429+ auto arrivalTime = llmRequest.getPerfMetrics ().timingMetrics .arrivalTime ;
430+ bool recordDelay = arrivalTime != std::chrono::steady_clock::time_point ();
431+
421432 auto pickUpConnections = pickRecvConnections (connections.size (), selfConfig, selfIdx, destConfig);
422433
423434 TLLM_LOG_DEBUG (" pickUpConnections size: %d connections size: %d" , pickUpConnections.size (), connections.size ());
@@ -546,7 +557,7 @@ void CacheFormatter::unformat(TransferSession& session)
546557 }
547558 TLLM_LOG_DEBUG (mpi::MpiComm::world ().getRank (),
548559 " End receiving KV cache for request ID: %ld, context request ID: %ld." , llmRequest.mRequestId ,
549- llmRequest. getContextPhaseParams (). value (). getReqId () );
560+ ctxReqId );
550561 return ;
551562 }
552563 // legacyPath: the context executor rank only sends data to one gen executor rank. It sends multiple cache
@@ -634,6 +645,8 @@ void CacheFormatter::unformat(TransferSession& session)
634645 TLLM_CUDA_CHECK (cudaSetDevice (deviceId));
635646 TLLM_CHECK (pickUpConnections.size () > processIdx);
636647 TLLM_CHECK (recvSplitCaches.size () > processIdx);
648+ auto startTime = std::chrono::steady_clock::now ();
649+ size_t size = 0 ;
637650 if (legacyPath)
638651 {
639652 size_t idx = processIdx * blockNum;
@@ -645,6 +658,7 @@ void CacheFormatter::unformat(TransferSession& session)
645658 size_t recvBufferIdx = blockIdx * pickUpConnections.size () + commIdx;
646659 llmRequest.updateKvCacheSize ((*recvSplitCaches[recvBufferIdx]).getSizeInBytes ());
647660 auto & buffer = recvSplitCaches.at (recvBufferIdx);
661+ size += buffer->getSizeInBytes ();
648662 session.recv (pickUpConnections[processIdx], buffer->data (), buffer->getSizeInBytes ());
649663 idx++;
650664 }
@@ -655,6 +669,7 @@ void CacheFormatter::unformat(TransferSession& session)
655669 {
656670 llmRequest.updateKvCacheSize ((*recvSplitCaches.at (processIdx)).getSizeInBytes ());
657671 auto & buffer = recvSplitCaches[processIdx];
672+ size = buffer->getSizeInBytes ();
658673 session.recv (pickUpConnections[processIdx], buffer->data (), buffer->getSizeInBytes ());
659674 }
660675 else if (bufferCoverTargetNum > 0 )
@@ -663,6 +678,7 @@ void CacheFormatter::unformat(TransferSession& session)
663678 + remainNoCoverTargetNum; // caches.at(recvBufferIdx) is allocated by cudaMalloc
664679 llmRequest.updateKvCacheSize ((*recvSplitCaches.at (recvBufferIdx)).getSizeInBytes ());
665680 auto & buffer = recvSplitCaches.at (recvBufferIdx);
681+ size = buffer->getSizeInBytes ();
666682 session.recv (pickUpConnections[processIdx], buffer->data (), buffer->getSizeInBytes ());
667683 bufferManager.copy (*recvSplitCaches.at (recvBufferIdx), *recvSplitCaches[processIdx]);
668684 bufferManager.getStream ().synchronize ();
@@ -679,6 +695,7 @@ void CacheFormatter::unformat(TransferSession& session)
679695 auto recvSlice = runtime::ITensor::slice (preAllocRecvBuffer, 0 , recvSize);
680696 auto copySlice = runtime::ITensor::slice (
681697 recvSplitCaches[processIdx], targetBufferSize - remainRecvSize, recvSize);
698+ size += recvSlice->getSizeInBytes ();
682699 llmRequest.updateKvCacheSize ((*recvSlice).getSizeInBytes ());
683700 session.recv (pickUpConnections[processIdx], recvSlice->data (), recvSlice->getSizeInBytes ());
684701 bufferManager.copy (*recvSlice, *copySlice);
@@ -687,6 +704,15 @@ void CacheFormatter::unformat(TransferSession& session)
687704 }
688705 }
689706 }
707+ auto endTime = std::chrono::steady_clock::now ();
708+ double delay = 0.0 ;
709+ if (recordDelay)
710+ {
711+ delay = std::chrono::duration<double , std::milli>(startTime - arrivalTime).count ();
712+ }
713+ double cacheTransferTime
714+ = std::max (0.0 , std::chrono::duration<double , std::milli>(endTime - startTime).count ());
715+ kvCacheMeasureHelper.appendKVCacheTransfer (ctxReqId, delay, cacheTransferTime, size);
690716 };
691717 if (pickUpConnections.size () > 1 )
692718 {
@@ -814,6 +840,8 @@ void CacheFormatter::unformat(TransferSession& session)
814840 if (selfConfig.getModelConfig ().mNbKvHeadsPerLayer .size () != destConfig.getModelConfig ().mNbKvHeadsPerLayer .size ())
815841 {
816842 TLLM_LOG_WARNING (" CacheFormatter::inquireSupport: only support same number of layers" );
843+ TLLM_LOG_WARNING (" self: %zu dest %zu" , selfConfig.getModelConfig ().mNbKvHeadsPerLayer .size (),
844+ destConfig.getModelConfig ().mNbKvHeadsPerLayer .size ());
817845 return false ;
818846 }
819847 int selfNumLayers = selfConfig.getModelConfig ().mNbKvHeadsPerLayer .size ();
0 commit comments