 #include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/executor/executor.h"
+#include "tensorrt_llm/executor/serialization.h"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"

 namespace tle = tensorrt_llm::executor;

 namespace tensorrt_llm::batch_manager::kv_cache_manager
 {

 KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank,
-    std::optional<SizeType32> attentionDpSize, std::optional<SizeType32> ppSize)
+    std::optional<SizeType32> attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs)
     : mRun{true}
     , mMaxSize{maxKVEventEntries}
     , mEventId{0}
     , mAttentionDpRank{attentionDpRank}
     , mAttentionDpSize{attentionDpSize}
+    , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
 {
+
     TLLM_CHECK(mMaxSize > 0);
     if (mAttentionDpRank)
     {
         TLLM_CHECK_WITH_INFO(
             mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set");
-        TLLM_CHECK(ppSize.has_value());
-        TLLM_CHECK_WITH_INFO(ppSize.value() == 1, "Events with attention DP are not supported with PP > 1");
         TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(),
             "Attention DP rank must be less than attention DP size");
+        if (mAttentionDpRank.value() == 0)
+        {
+            // Rank 0 will gather events from all other ranks
+            // Need to increase size
+            mMaxSize *= mAttentionDpSize.value();
+        }
     }
     else
     {
         TLLM_CHECK_WITH_INFO(
-            !mAttentionDpSize.has_value(), "If attention DP size is set, the attention DP rank must also be set");
+            !mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set");
     }
-    // mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this));
     mWorkerThread = std::thread([this]() { this->worker(); });
-    mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpEvents(); });
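+    // When attention DP is enabled, a second thread periodically exchanges events across ranks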
+    if (mAttentionDpRank)
+    {
+        mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); });
+    }
 };

 KVCacheEventManager::~KVCacheEventManager()
@@ -58,7 +68,10 @@ KVCacheEventManager::~KVCacheEventManager()
     mPendingEmptyCV.notify_all();
     mEmptyCV.notify_all();
     mWorkerThread.join();
-    mAttentionDpExchangeThread.join();
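+    // The exchange thread only exists when attention DP is enabled, so join it conditionally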
+    if (mAttentionDpRank)
+    {
+        mExchangeAttentionDpThread.join();
+    }
 }

 void KVCacheEventManager::enqueueCreatedEvent(
@@ -84,7 +97,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
     for (auto const& block : blocks)
     {
         data.blocks.emplace_back(block->getHash(), block->getUniqueTokens(), block->getBlockKey().loraTaskId,
-            block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority(), mAttentionDpRank);
+            block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
     }

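+    // The attention DP rank is attached to the event itself, not to each block entry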
     enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
@@ -100,7 +113,7 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32
     }
     else
     {
-        enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
+        enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank});
     }
 }

@@ -136,28 +149,27 @@ void KVCacheEventManager::flush()
     auto eventQueue = std::exchange(mEventQueue, {});
     std::unique_lock<std::mutex> lck(mPendingEventsMutex);
     mPendingEvents.push_back(std::move(eventQueue));
-    // If we have events, we need to notify the worker thread to process them
     mPendingEmptyCV.notify_one();
 }

 void KVCacheEventManager::exchangeAttentionDpThread()
 {
-    int32_t pollPeriodMs = 5;
     while (true)
     {
-        // If we are not rank 0, send events asynchronously
+        TLLM_CHECK(mAttentionDpRank);
+        // If we are not rank 0, send events to rank 0
         if (mAttentionDpRank.value() != 0)
         {
             std::vector<char> serializedEvents;
             {
                 std::unique_lock<std::mutex> lck(mEventsMutex);
-                serializedEvents = Serialization::serialize(mEvents);
+                serializedEvents = executor::Serialization::serialize(mEvents);
                 mEvents.clear();
             }
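+            // Send the payload size first so rank 0 can size its receive buffer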
             uint64_t vecSize = serializedEvents.size();
-            COMM_SESSION.send(&vecSize, 1, MpiType::kUINT64, 0, MpiTag::kKVCacheEventSize);
+            COMM_SESSION.send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
             COMM_SESSION.send(
-                serializedEvents.data(), serializedEvents.size(), MpiType::kCHAR, 0, MpiTag::kKVCacheEvent);
+                serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, mpi::MpiTag::kKvCacheEvent);
         }
         else
         {
@@ -167,18 +179,18 @@ void KVCacheEventManager::exchangeAttentionDpThread()
             while (numRecvs < mAttentionDpSize.value() - 1)
             {
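+                // Probe for a pending message from any rank without blocking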
                 MPI_Status probeStatus;
-                if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, MpiTag::kKVCacheEvent, &status))
+                if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, mpi::MpiTag::kKvCacheEvent, &probeStatus))
                 {
-                    uint64_t vecSize;
+                    uint64_t vecSize{0};
                     COMM_SESSION.recv(
-                        &vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKVCacheEventSize);
+                        &vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKvCacheEventSize);

                     std::vector<char> serializedEvents(vecSize);
-                    COMM_SESSION.recv(&serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
-                        mpi::MpiTag::kKVCacheEvent);
+                    COMM_SESSION.recv(serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
+                        mpi::MpiTag::kKvCacheEvent);

                     // Deserialize the events and add them to the local queue
-                    auto rankEvents = Serialization::deserializeKVCacheEvents(serializedEvents);
+                    auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents);
                     {
                         std::unique_lock<std::mutex> lck(mEventsMutex);
                         mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
@@ -187,47 +199,47 @@ void KVCacheEventManager::exchangeAttentionDpThread()
                     numRecvs++;
                 }
             }
-            std::this_thread::sleep_for(std::chrono::milliseconds(pollPeriodMs));
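+            // Wait for the configured gather period before polling for more events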
+            std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
         }
     }
+}

-    void KVCacheEventManager::worker()
-    {
+void KVCacheEventManager::worker()
+{

-        while (true)
+    while (true)
+    {
+        std::deque<tle::KVCacheEvent> events;
         {
-            std::deque<tle::KVCacheEvent> events;
+            std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
+            mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
+            if (!mRun)
             {
-                std::unique_lock<std::mutex> pendingLock(mPendingEventsMutex);
-                mPendingEmptyCV.wait(pendingLock, [this] { return !mPendingEvents.empty() || !mRun; });
-                if (!mRun)
-                {
-                    return;
-                }
-                events = mPendingEvents.front();
-                mPendingEvents.pop_front();
+                return;
             }
+            events = mPendingEvents.front();
+            mPendingEvents.pop_front();
+        }

-            std::unique_lock<std::mutex> lck(mEventsMutex);
+        std::unique_lock<std::mutex> lck(mEventsMutex);

-            SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;
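+        // Number of events that exceed the queue capacity; negative means there is still room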
+        SizeType32 elementsToRemove = mEvents.size() + events.size() - mMaxSize;

-            // First, take elements from mEvents since they are the oldest.
-            if (elementsToRemove > 0)
-            {
-                SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
-                mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
-                elementsToRemove -= numRemoved;
-                TLLM_LOG_WARNING(
-                    "The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
-            }
+        // First, take elements from mEvents since they are the oldest.
+        if (elementsToRemove > 0)
+        {
+            SizeType32 numRemoved = std::min(static_cast<SizeType32>(mEvents.size()), elementsToRemove);
+            mEvents.erase(mEvents.begin(), mEvents.begin() + numRemoved);
+            elementsToRemove -= numRemoved;
+            TLLM_LOG_WARNING("The event queue has reached the max size of %d. Events have been discarded.", mMaxSize);
+        }

-            // If there's still too many events, take from the front of the events queue.
-            mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());
+        // If there's still too many events, take from the front of the events queue.
+        mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());

-            // Notify the empty condition variable to wake up any waiting threads
-            mEmptyCV.notify_one();
-        }
+        // Notify the empty condition variable to wake up any waiting threads
+        mEmptyCV.notify_one();
     }
+}

 } // namespace tensorrt_llm::batch_manager::kv_cache_manager