
Commit 82078cb

Fixing test
Signed-off-by: Patrice Castonguay <[email protected]>
1 parent: a73d28d

File tree: 7 files changed, +97 −81 lines

cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h

Lines changed: 3 additions & 0 deletions

```diff
@@ -105,6 +105,9 @@ class KVCacheEventManager
     /// @brief The period in milliseconds to gather attention DP events across rank
     SizeType32 mAttentionDpEventsGatherPeriodMs;
+
+    /// @brief MPI communicator for attention DP
+    std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiComm;
 };

 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
```

cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp

Lines changed: 29 additions & 17 deletions

```diff
@@ -35,7 +35,6 @@ KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional
     , mAttentionDpSize{attentionDpSize}
     , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
 {
-
     TLLM_CHECK(mMaxSize > 0);
     if (mAttentionDpRank)
     {
```
```diff
@@ -49,6 +48,8 @@ KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional
             // Need to increase size
             mMaxSize *= mAttentionDpSize.value();
         }
+        // Create a communicator to be used for event exchange
+        mMpiComm = std::make_unique<tensorrt_llm::mpi::MpiComm>(COMM_SESSION.split(0, mAttentionDpRank.value()));
     }
     else
     {
```
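Every rank passes the same color (0) to the split, so the resulting communicator contains all ranks of COMM_SESSION, ordered by attention DP rank, and gives event exchange its own channel that cannot collide with other traffic on the session communicator. A minimal sketch of the same pattern, assuming mpi4py in place of TensorRT-LLM's MpiComm wrapper:

```python
from mpi4py import MPI

# Every rank passes color 0, so the child communicator contains all
# ranks of the parent; the key orders them by attention DP rank.
attention_dp_rank = MPI.COMM_WORLD.Get_rank()  # stand-in for mAttentionDpRank
event_comm = MPI.COMM_WORLD.Split(color=0, key=attention_dp_rank)

# Sends and receives on event_comm can never be matched against
# unrelated messages travelling on the parent communicator.
```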
```diff
@@ -162,37 +163,49 @@ void KVCacheEventManager::exchangeAttentionDpThread()
     while (true)
     {
         TLLM_CHECK(mAttentionDpRank);
+
+        // Check if any of the ranks have been shut down
+        int32_t numFinished = 0;
+        int32_t finished = mRun ? 0 : 1;
+        mMpiComm->allreduce(&finished, &numFinished, 1, mpi::MpiType::kINT32, mpi::MpiOp::SUM);
+        if (numFinished > 0)
+        {
+            TLLM_LOG_INFO("One of the ranks has been shut down, exiting");
+            break;
+        }
+
         // If we are not rank 0, send events to rank 0
         if (mAttentionDpRank.value() != 0)
         {
             std::vector<char> serializedEvents;
+            uint64_t numEvents = 0;
             {
                 std::unique_lock<std::mutex> lck(mEventsMutex);
                 serializedEvents = executor::Serialization::serialize(mEvents);
+                numEvents = mEvents.size();
                 mEvents.clear();
             }
-            uint64_t vecSize = serializedEvents.size();
-            COMM_SESSION.send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
-            COMM_SESSION.send(
-                serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, mpi::MpiTag::kKvCacheEvent);
+            uint64_t vecSize = numEvents > 0 ? serializedEvents.size() : 0;
+            mMpiComm->send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
+            if (vecSize > 0)
+            {
+                mMpiComm->send(serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0,
+                    mpi::MpiTag::kKvCacheEvent);
+            }
         }
         else
         {
             TLLM_CHECK(mAttentionDpSize.has_value());
             // Loop until we have received events from all ranks
-            int32_t numRecvs = 0;
-            while (numRecvs < mAttentionDpSize.value() - 1)
+            for (int rank = 1; rank < mAttentionDpSize.value(); ++rank)
             {
-                MPI_Status probeStatus;
-                if (COMM_SESSION.iprobe(MPI_ANY_SOURCE, mpi::MpiTag::kKvCacheEvent, &probeStatus))
+                uint64_t vecSize{0};
+                mMpiComm->recv(&vecSize, 1, mpi::MpiType::kUINT64, rank, mpi::MpiTag::kKvCacheEventSize);
+                if (vecSize > 0)
                 {
-                    uint64_t vecSize{0};
-                    COMM_SESSION.recv(
-                        &vecSize, 1, mpi::MpiType::kUINT64, probeStatus.MPI_SOURCE, mpi::MpiTag::kKvCacheEventSize);
-
                     std::vector<char> serializedEvents(vecSize);
-                    COMM_SESSION.recv(serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, probeStatus.MPI_SOURCE,
-                        mpi::MpiTag::kKvCacheEvent);
+                    mMpiComm->recv(
+                        serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, rank, mpi::MpiTag::kKvCacheEvent);

                     // Deserialize the events and add them to the local queue
                     auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents);
```
```diff
@@ -201,11 +214,10 @@ void KVCacheEventManager::exchangeAttentionDpThread()
                     mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
                     mEmptyCV.notify_one();
                 }
-                numRecvs++;
             }
         }
-            std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
         }
+        std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
     }
 #else
     TLLM_THROW("Multi device support is disabled.");
```
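Taken together, the gather thread drops the opportunistic iprobe polling on COMM_SESSION in favor of a deterministic protocol on the dedicated communicator: an allreduce lets any rank signal shutdown, every non-zero rank sends a size header each period (zero when it has no events), and rank 0 posts exactly one matching receive per peer. A runnable sketch of the protocol, assuming mpi4py and pickle stand in for the project's MpiComm and Serialization helpers (locking omitted):

```python
import pickle
import time

from mpi4py import MPI

# Dedicated communicator, mirroring the constructor change above.
comm = MPI.COMM_WORLD.Split(color=0, key=MPI.COMM_WORLD.Get_rank())
TAG_SIZE, TAG_EVENTS = 1, 2  # stand-ins for kKvCacheEventSize / kKvCacheEvent


def exchange_loop(local_events, running, period_ms):
    while True:
        # Cooperative shutdown: if any rank has stopped, every rank exits.
        num_finished = comm.allreduce(0 if running() else 1, op=MPI.SUM)
        if num_finished > 0:
            break
        if comm.Get_rank() != 0:
            # Always send a size header; skip the payload when empty.
            payload = pickle.dumps(local_events)
            size = len(payload) if local_events else 0
            local_events.clear()
            comm.send(size, dest=0, tag=TAG_SIZE)
            if size > 0:
                comm.send(payload, dest=0, tag=TAG_EVENTS)
        else:
            # Deterministic gather: one size header per peer, in rank order.
            for rank in range(1, comm.Get_size()):
                size = comm.recv(source=rank, tag=TAG_SIZE)
                if size > 0:
                    events = pickle.loads(comm.recv(source=rank, tag=TAG_EVENTS))
                    local_events.extend(events)
        time.sleep(period_ms / 1000.0)
```

Because the size header always arrives, rank 0 can block on each peer in turn without MPI_ANY_SOURCE probing, and the zero-size sentinel avoids shipping a serialized empty queue.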

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 3 additions & 1 deletion

```diff
@@ -325,7 +325,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
         .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0);

     nb::class_<tbk::KVCacheEventManager>(m, "KVCacheEventManager")
-        .def(nb::init<size_t>(), nb::arg("max_kv_event_entries"));
+        .def(nb::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(),
+            nb::arg("max_kv_event_entries"), nb::arg("attention_dp_rank"), nb::arg("attention_dp_size"),
+            nb::arg("attention_dp_events_gather_period_ms"));

     nb::class_<tbk::BaseKVCacheManager, PyKvCacheManager>(m, "BaseKVCacheManager")
         .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, nb::arg("config"),
```

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 3 additions & 1 deletion

```diff
@@ -321,7 +321,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
         .def_static("hash", &tbk::BlockKeyHasher::hash, py::arg("block_key"), py::arg("parent_hash") = 0);

     py::class_<tbk::KVCacheEventManager, std::shared_ptr<tbk::KVCacheEventManager>>(m, "KVCacheEventManager")
-        .def(py::init<size_t>(), py::arg("max_kv_event_entries"));
+        .def(py::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(),
+            py::arg("max_kv_event_entries"), py::arg("attention_dp_rank"), py::arg("attention_dp_size"),
+            py::arg("attention_dp_events_gather_period_ms"));

     py::classh<tbk::BaseKVCacheManager, PyKvCacheManager>(m, "BaseKVCacheManager")
         .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, py::arg("config"),
```
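Both binding layers now expose the same four-argument constructor, so Python callers must pass the attention DP parameters explicitly. A hypothetical construction from Python (the internal module path is an assumption and may differ):

```python
# Hypothetical import path; the internal bindings layout may differ.
from tensorrt_llm.bindings.internal.batch_manager import KVCacheEventManager

event_manager = KVCacheEventManager(
    max_kv_event_entries=1024,
    attention_dp_rank=0,    # std::optional<SizeType32>: None when DP is off
    attention_dp_size=2,    # std::optional<SizeType32>: None when DP is off
    attention_dp_events_gather_period_ms=10,
)
```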

cpp/tensorrt_llm/pybind/executor/executorConfig.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -108,7 +108,7 @@ void initConfigBindings(pybind11::module_& m)
     };
     auto kvCacheConfigSetstate = [](py::tuple const& state)
     {
-        if (state.size() != 13)
+        if (state.size() != 14)
         {
             throw std::runtime_error("Invalid state!");
         }
```
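The guard changes because KvCacheConfig's pickle state gains a 14th entry, presumably the new attention_dp_events_gather_period_ms value, so __setstate__ must accept the longer tuple. A sketch of the round-trip this check protects, assuming the executor bindings import path:

```python
import pickle

from tensorrt_llm.bindings import executor as trtllm  # assumed path

config = trtllm.KvCacheConfig(enable_block_reuse=True)
# __getstate__ now packs 14 fields; a 13-field tuple pickled by an
# older build would raise "Invalid state!" in __setstate__.
restored = pickle.loads(pickle.dumps(config))
assert restored.enable_block_reuse
```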

cpp/tests/unit_tests/batch_manager/kvCacheEventManagerTest.cpp

Whitespace-only changes.

tests/unittest/llmapi/test_llm_kv_cache_events.py

Lines changed: 58 additions & 61 deletions

```diff
@@ -147,14 +147,11 @@ async def main():
     asyncio.run(main())


-def test_llm_kv_events_api():
-    llm = create_llm()
-    sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
-
-    requests = []
-    for i in range(3):
-        input_tokens = list(range(127 + i))[i:]
-        requests.append(input_tokens)
+def check_events(llm,
+                 requests,
+                 sampling_params,
+                 scheduling_params=None,
+                 attention_dp_rank=None):

     _ = llm.generate(requests[0], sampling_params=sampling_params)
     events1 = llm.get_kv_cache_events(5)
```
```diff
@@ -163,52 +160,95 @@ def test_llm_kv_events_api():
     event = events1.pop(0)  # created event
     while events1:
         event = events1.pop(0)
+        print("event1:", event)
         if event:
             assert event["event_id"] == 1
             assert event["data"]["type"] == "stored"
             assert len(event["data"]["blocks"]) == 5
+            if attention_dp_rank:
+                assert event["data"]["attention_dp_rank"] == attention_dp_rank

     _ = llm.generate(requests[1], sampling_params=sampling_params)
     events2 = llm.get_kv_cache_events(5)

     while events2:
         event = events2.pop(0)
+        print("event2:", event)
         if event:
             if event["event_id"] == 2:
                 # 2 removed events needed
                 # should be a removed event to make space for context block
                 assert event["data"]["type"] == "removed"
                 assert event["data"]["block_hashes"]
+                if attention_dp_rank:
+                    assert event["data"][
+                        "attention_dp_rank"] == attention_dp_rank
             elif event["event_id"] == 3:
                 assert event["data"]["type"] == "removed"
                 assert event["data"]["block_hashes"]
+                if attention_dp_rank:
+                    assert event["data"][
+                        "attention_dp_rank"] == attention_dp_rank
             # stored event for 2nd request
             elif event["event_id"] == 4:
                 assert event["data"]["type"] == "stored"
                 assert len(event["data"]["blocks"]) == 5
+                if attention_dp_rank:
+                    assert event["data"][
+                        "attention_dp_rank"] == attention_dp_rank

     _ = llm.generate(requests[2], sampling_params=sampling_params)
     events3 = llm.get_kv_cache_events(5)

     while events3:
         event = events3.pop(0)
+        print("event3:", event)
         if event:
             if event["event_id"] == 5:
                 assert event["data"]["type"] == "removed"
                 assert event["data"]["block_hashes"]
+                if attention_dp_rank:
+                    assert event["data"][
+                        "attention_dp_rank"] == attention_dp_rank
             elif event["event_id"] == 6:
                 assert event["data"]["type"] == "removed"
                 assert event["data"]["block_hashes"]
+                if attention_dp_rank:
+                    assert event["data"][
+                        "attention_dp_rank"] == attention_dp_rank
             elif event["event_id"] == 7:
                 assert event["data"]["type"] == "stored"
                 assert len(event["data"]["blocks"]) == 5
+                if attention_dp_rank:
+                    assert event["data"][
+                        "attention_dp_rank"] == attention_dp_rank

     # no more events after request is finished
     assert not llm.get_kv_cache_events(5)


+def test_llm_kv_events_api():
+    llm = create_llm()
+    sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
+
+    requests = []
+    for i in range(3):
+        input_tokens = list(range(127 + i))[i:]
+        requests.append(input_tokens)
+
+    check_events(llm, requests, sampling_params)
+
+
 @skip_single_gpu
 def test_llm_api_attention_dp_kv_events():
+
+    kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
+                                   event_buffer_max_size=1024,
+                                   attention_dp_events_gather_period_ms=10,
+                                   enable_block_reuse=True,
+                                   onboard_blocks=True,
+                                   max_tokens=256)
+
     llm = LLM(model=llama_model_path,
               tensor_parallel_size=2,
               enable_attention_dp=True,
```
```diff
@@ -217,59 +257,16 @@ def test_llm_api_attention_dp_kv_events():

     sampling_params = SamplingParams(max_tokens=6, temperature=0.01)

-    requests = []
-    for i in range(3):
-        input_tokens = list(range(127 + i))[i:]
-        requests.append(input_tokens)
-
-    _ = llm.generate(requests[0], sampling_params=sampling_params)
-    events1 = llm.get_kv_cache_events(5)
-
-    # Should have 1 stored event and 1 created event
-    event = events1.pop(0)  # created event
-    while events1:
-        event = events1.pop(0)
-        if event:
-            assert event["event_id"] == 1
-            assert event["data"]["type"] == "stored"
-            assert event["attention_dp_rank"] == 0
-            assert event["window_size"] == 32
-            assert len(event["data"]["blocks"]) == 5
+    for attention_dp_rank in range(2):
+        requests = []
+        for i in range(3):
+            input_tokens = list(range(127 + i))[i:]
+            requests.append(input_tokens)

-    _ = llm.generate(requests[1], sampling_params=sampling_params)
-    events2 = llm.get_kv_cache_events(5)
+        scheduling_params = SchedulingParams(
+            attention_dp_rank=attention_dp_rank, attention_dp_relax=False)

-    while events2:
-        event = events2.pop(0)
-        if event:
-            if event["event_id"] == 2:
-                # 2 removed events needed
-                # should be a removed event to make space for context block
-                assert event["data"]["type"] == "removed"
-                assert event["data"]["block_hashes"]
-            elif event["event_id"] == 3:
-                assert event["data"]["type"] == "removed"
-                assert event["data"]["block_hashes"]
-            # stored event for 2nd request
-            elif event["event_id"] == 4:
-                assert event["data"]["type"] == "stored"
-                assert len(event["data"]["blocks"]) == 5
+        check_events(llm, requests, sampling_params, scheduling_params,
+                     attention_dp_rank)

-    #_ = llm.generate(requests[2], sampling_params=sampling_params)
-    #events3 = llm.get_kv_cache_events(5)
-
-    #while events3:
-    #    event = events3.pop(0)
-    #    if event:
-    #        if event["event_id"] == 5:
-    #            assert event["data"]["type"] == "removed"
-    #            assert event["data"]["block_hashes"]
-    #        elif event["event_id"] == 6:
-    #            assert event["data"]["type"] == "removed"
-    #            assert event["data"]["block_hashes"]
-    #        elif event["event_id"] == 7:
-    #            assert event["data"]["type"] == "stored"
-    #            assert len(event["data"]["blocks"]) == 5
-
-    ## no more events after request is finished
-    #assert not llm.get_kv_cache_events(5)
+    time.sleep(5)
```
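With attention DP enabled, each event payload now carries the rank that produced it, and the trailing time.sleep(5) presumably gives the rank-0 gather thread a final period to drain peer events before the LLM is torn down. An illustrative way a consumer might attribute drained events per rank (not part of the test itself):

```python
# Illustrative only: group drained events by the emitting DP rank.
events = llm.get_kv_cache_events(5)  # same 5-second timeout the test uses
events_by_rank = {}
for event in events:
    if event:
        rank = event["data"].get("attention_dp_rank")
        events_by_rank.setdefault(rank, []).append(event)
```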
