Commit de47282
[TRTLLM-6637][feat] Resolve KV cache divergence issue (#6628)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent d643aef commit de47282

File tree: 8 files changed, +109 −48 lines

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 33 additions & 11 deletions
@@ -828,8 +828,10 @@ class GenericLlmRequest
         // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
         mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
                                                                      : LlmRequestState::kCONTEXT_INIT;
-        mContextCurrentPosition = 0;
-        mPrepopulatedPromptLen = 0;
+        mContextCurrentPositionTarget = 0;
+        mContextCurrentPositionDraft = 0;
+        mPrepopulatedPromptLenTarget = 0;
+        mPrepopulatedPromptLenDraft = 0;
         mContextChunkSize = mPromptLen;
         mSeqSlot.reset();
     }
@@ -1049,7 +1051,7 @@ class GenericLlmRequest
 
     [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
     {
-        return mPrepopulatedPromptLen;
+        return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
     }
 
     void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
@@ -1066,7 +1068,10 @@ class GenericLlmRequest
             "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
             promptLen, mRequestId);
         TLLM_CHECK(prepopulatedPromptLen < promptLen);
-        mPrepopulatedPromptLen = prepopulatedPromptLen;
+
+        auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
+        auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
+        prePromptLen = prepopulatedPromptLen;
 
         if (prepopulatedPromptLen > 0)
         {
@@ -1081,7 +1086,7 @@ class GenericLlmRequest
             chunkSize = flooredEndPosition - prepopulatedPromptLen;
             TLLM_CHECK(chunkSize <= getContextChunkSize());
         }
-        setContextCurrentPosition(prepopulatedPromptLen);
+        contextCurrentPosition = prepopulatedPromptLen;
         setContextChunkSize(chunkSize);
 
         if (!isLastContextChunk())
@@ -1522,14 +1527,15 @@ class GenericLlmRequest
 
     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
     {
-        mContextCurrentPosition = contextCurrentPosition;
+        mContextCurrentPositionDraft = contextCurrentPosition;
+        mContextCurrentPositionTarget = contextCurrentPosition;
     }
 
     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
     /// or end of the context is returned.
     [[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept
     {
-        return mContextCurrentPosition;
+        return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
     }
 
     /// Return the length of the context that has not yet been processed.
@@ -1570,14 +1576,16 @@ class GenericLlmRequest
     {
         // The number of cached token is encountered in mContextCurrentPosition,
         // so the start position of the context is mPrepopulatedPromptLen.
-        return mContextCurrentPosition == mPrepopulatedPromptLen;
+        return getContextCurrentPosition() == getPrepopulatedPromptLen();
     }
 
     /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
     void moveToNextContextChunk()
    {
         TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
-        mContextCurrentPosition += getContextChunkSize();
+
+        mContextCurrentPositionDraft += getContextChunkSize();
+        mContextCurrentPositionTarget += getContextChunkSize();
         setContextChunkSize(0);
     }
 
@@ -1843,6 +1851,16 @@ class GenericLlmRequest
         return mIsDummyRequest;
     }
 
+    void setUseDraftModel(bool useDraftModel)
+    {
+        mUseDraftModel = useDraftModel;
+    }
+
+    [[nodiscard]] bool useDraftModel() const
+    {
+        return mUseDraftModel;
+    }
+
     RequestIdType mRequestId;
     SizeType32 mPromptLen;
     SizeType32 mMaxNewTokens;
@@ -1885,7 +1903,8 @@ class GenericLlmRequest
     // Number of tokens already in KV cache before context phase.
     // A value > 0 indicates cached KV cache blocks were reused.
     // Up to inputLen - 1 tokens can be reused.
-    SizeType32 mPrepopulatedPromptLen{0};
+    SizeType32 mPrepopulatedPromptLenTarget{0};
+    SizeType32 mPrepopulatedPromptLenDraft{0};
 
     SizeType32 mMaxSentTokenLen;
 
@@ -1916,7 +1935,8 @@ class GenericLlmRequest
     // The size of the context chunk must be multiple of the KV-Cache block size except the last one.
     // Value `0` means Chunked-Context is disabled.
     SizeType32 mContextChunkSize{0};
-    SizeType32 mContextCurrentPosition{0};
+    SizeType32 mContextCurrentPositionTarget{0};
+    SizeType32 mContextCurrentPositionDraft{0};
 
     std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
     VecLogProbs mCumLogProbs;           // [beamSize]
@@ -2017,6 +2037,8 @@ class GenericLlmRequest
 
     bool mIsDummyRequest{false};
 
+    bool mUseDraftModel{false};
+
 private:
     void initialize(VecTokens const& inputTokens, bool outputLogProbs)
     {
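
As a reading aid (not part of the commit), here is a minimal Python sketch of the bookkeeping that llmRequest.h now keeps: two copies of the prepopulated-prompt length and the context position, with the getters selecting the draft or target copy based on a use_draft_model flag. Class and method names below are simplified stand-ins for the C++ members, not the real API.

class RequestPositionsSketch:
    """Simplified model of the draft/target split in GenericLlmRequest."""

    def __init__(self):
        self.use_draft_model = False
        # Separate counters so the draft and target models no longer
        # overwrite each other's KV-cache reuse/position state.
        self._prepopulated = {"draft": 0, "target": 0}
        self._position = {"draft": 0, "target": 0}

    def _side(self) -> str:
        return "draft" if self.use_draft_model else "target"

    def set_prepopulated_prompt_len(self, n: int) -> None:
        # Mirrors setPrepopulatedPromptLen(): only the active side is updated.
        self._prepopulated[self._side()] = n
        self._position[self._side()] = n

    def get_context_current_position(self) -> int:
        return self._position[self._side()]

    def move_to_next_context_chunk(self, chunk_size: int) -> None:
        # Mirrors moveToNextContextChunk(): both sides advance together.
        self._position["draft"] += chunk_size
        self._position["target"] += chunk_size

    def is_first_context_chunk(self) -> bool:
        return self.get_context_current_position() == self._prepopulated[self._side()]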

cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp

Lines changed: 1 addition & 2 deletions
@@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests)
                 continue;
             }
             auto const seqSlot = llmReq->mSeqSlot.value();
-            if (llmReq->isContextInitState()
-                && llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen())
+            if (llmReq->isContextInitState() && llmReq->isFirstContextChunk())
             {
                 // The request is in the first context forward step (considering kv cache reuse).
                 auto const& guideType = guidedDecodingParams->getGuideType();

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -248,7 +248,8 @@ void initBindings(nb::module_& m)
                 }
             })
         .def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
-        .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
+        .def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
+        .def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
 
     nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
         .def(

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -253,7 +253,8 @@ void initBindings(pybind11::module_& m)
                 }
             })
         .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
-        .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics);
+        .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
+        .def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
 
     py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
         .def(py::init<>(

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 0 deletions
@@ -314,6 +314,7 @@ def _create_kv_cache_manager(
             dtype=kv_cache_dtype,
             spec_config=spec_config,
             max_beam_width=executor_config.max_beam_width,
+            is_draft=model_engine.is_draft_model,
         )
     elif is_nemotron_hybrid(config):
         if executor_config.max_beam_width > 1:
@@ -376,6 +377,7 @@ def _create_kv_cache_manager(
             max_num_tokens=executor_config.max_num_tokens,
             model_config=binding_model_config,
             max_beam_width=executor_config.max_beam_width,
+            is_draft=model_engine.is_draft_model,
         )
     # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
     if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 1 addition & 0 deletions
@@ -339,6 +339,7 @@ def __init__(
         self.py_seq_slot = seq_slot
         # If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
         self.py_target_seq_slot = target_seq_slot
+        self.use_draft_model = is_draft
 
         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 6 deletions
@@ -17,7 +17,8 @@
 except ImportError:
     from cuda import cudart
 
-from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
+from tensorrt_llm._torch.pyexecutor.resource_manager import (
+    ResourceManagerType, request_context)
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
                                  is_trace_enabled, nvtx_range, trace_func)
@@ -937,11 +938,14 @@ def _executor_loop(self):
                         self.guided_decoder.init_disagg_gen_requests(
                             scheduled_batch)
                     if self.drafter is not None and self.use_spec_decode:
-                        if self.guided_decoder is not None:
-                            self.guided_decoder.rollback_rejected_tokens(
-                                scheduled_batch)
-                        self.drafter.prepare_draft_tokens(
-                            scheduled_batch, self.resource_manager)
+                        with request_context(
+                                is_draft=True,
+                                scheduled_requests=scheduled_batch):
+                            if self.guided_decoder is not None:
+                                self.guided_decoder.rollback_rejected_tokens(
+                                    scheduled_batch)
+                            self.drafter.prepare_draft_tokens(
+                                scheduled_batch, self.resource_manager)
 
                     batch_outputs = self._forward_step(scheduled_batch)
                     self._execute_guided_decoder(scheduled_batch,

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 58 additions & 27 deletions
@@ -110,6 +110,33 @@ def get_pp_layers(
     return pp_layers, total_num_layers
 
 
+def request_context(is_draft: bool, scheduled_requests: ScheduledRequests):
+
+    class RequestContext:
+
+        def __init__(self, is_draft: bool,
+                     scheduled_requests: ScheduledRequests):
+            self.is_draft = is_draft
+            self.scheduled_requests = scheduled_requests
+
+        def __enter__(self):
+            if not self.is_draft:
+                return
+
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = True
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            if not self.is_draft:
+                return
+
+            # Clean up the state
+            for req in self.scheduled_requests.all_requests():
+                req.use_draft_model = False
+
+    return RequestContext(is_draft, scheduled_requests)
+
+
 class KVCacheManager(BaseResourceManager):
 
     def __init__(
@@ -132,6 +159,7 @@ def __init__(
         max_num_tokens: int = 8192,
         model_config: Optional[ModelConfig] = None,
         max_beam_width: int = 1,
+        is_draft: bool = False,
     ) -> None:
         self.mapping = mapping
         self.dtype = dtype
@@ -142,6 +170,7 @@ def __init__(
             spec_config=spec_config,
            layer_mask=layer_mask,
         )
+        self.is_draft = is_draft
         self.num_local_layers = len(self.pp_layers)
         self.layer_offsets = {
             idx: offset
@@ -366,34 +395,36 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int:
         return need_blocks
 
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
-        context_batch = scheduled_batch.context_requests
-        generation_batch = scheduled_batch.generation_requests
-        # allocate KV Cache
-        for req in context_batch:
-            req_beam_width = req.sampling_config.beam_width
-            if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
-                    'cp_type']:
-                if req.ctx_iters == 0:
-                    seq_len = sum(
-                        len(ctx_block) for ctx_block in req.ctx_blocks)
-                    self.impl.add_sequence(
-                        req.py_request_id,
-                        seq_len + (len(req.query_id) if self.mapping.cp_rank
-                                   == self.mapping.cp_size - 1 else 0),
-                        req_beam_width, req)
-            else:
-                if req.is_first_context_chunk:
-                    self.impl.add_sequence(req.py_request_id, req.prompt_len,
-                                           req_beam_width, req)
-                    for _ in range(self.num_extra_kv_tokens):
-                        self.impl.add_token(req.py_request_id)
-                    for _ in range(get_draft_token_length(req)):
-                        self.impl.add_token(req.py_request_id)
-
-        for req in generation_batch:
-            self.impl.add_token(req.py_request_id)
-            for _ in range(get_draft_token_length(req)):
+        with request_context(self.is_draft, scheduled_batch):
+            context_batch = scheduled_batch.context_requests
+            generation_batch = scheduled_batch.generation_requests
+            # allocate KV Cache
+            for req in context_batch:
+                req_beam_width = req.sampling_config.beam_width
+                if 'cp_type' in self.mapping.cp_config and 'star_attention' == self.mapping.cp_config[
+                        'cp_type']:
+                    if req.ctx_iters == 0:
+                        seq_len = sum(
+                            len(ctx_block) for ctx_block in req.ctx_blocks)
+                        self.impl.add_sequence(
+                            req.py_request_id,
+                            seq_len + (len(req.query_id) if self.mapping.cp_rank
+                                       == self.mapping.cp_size - 1 else 0),
+                            req_beam_width, req)
+                else:
+                    if req.is_first_context_chunk:
+                        self.impl.add_sequence(req.py_request_id,
+                                               req.prompt_len, req_beam_width,
+                                               req)
+                        for _ in range(self.num_extra_kv_tokens):
+                            self.impl.add_token(req.py_request_id)
+                        for _ in range(get_draft_token_length(req)):
+                            self.impl.add_token(req.py_request_id)
+
+            for req in generation_batch:
                 self.impl.add_token(req.py_request_id)
+                for _ in range(get_draft_token_length(req)):
+                    self.impl.add_token(req.py_request_id)
 
     def add_dummy_requests(
         self,
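
For illustration only, a small usage sketch of the new request_context helper: the draft-model KVCacheManager wraps its work in this context so that every request temporarily reports use_draft_model=True, and the position getters read the draft-side counters. The stub classes below are hypothetical; only request_context itself comes from this commit.

from types import SimpleNamespace

from tensorrt_llm._torch.pyexecutor.resource_manager import request_context


class StubScheduledRequests:
    """Hypothetical stand-in exposing the all_requests() iteration used by request_context."""

    def __init__(self, requests):
        self._requests = requests

    def all_requests(self):
        return self._requests


reqs = [SimpleNamespace(use_draft_model=False) for _ in range(2)]
batch = StubScheduledRequests(reqs)

# While the draft model prepares KV-cache resources, requests are flagged
# as draft; the flag is cleared again on exit.
with request_context(is_draft=True, scheduled_requests=batch):
    assert all(r.use_draft_model for r in reqs)
assert not any(r.use_draft_model for r in reqs)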
