Skip to content

Commit 283714a

Browse files
committed
fix kv cache rewind issue
Signed-off-by: Yue Weng <[email protected]>
1 parent 75ec3b1 commit 283714a

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -535,14 +535,15 @@ def update_resources(self,
535535
scheduled_batch: ScheduledRequests,
536536
attn_metadata: "AttentionMetadata" = None,
537537
kv_cache_dtype_byte_size: float = None):
538-
self.update_kv_cache_draft_token_location(scheduled_batch,
539-
attn_metadata,
540-
kv_cache_dtype_byte_size)
541-
# rewind kv cache
542-
for request in scheduled_batch.generation_requests:
543-
if request.state != LlmRequestState.GENERATION_COMPLETE:
544-
if request.py_rewind_len > 0:
545-
self.rewind_kv_cache(request, request.py_rewind_len)
538+
if not self.is_draft:
539+
self.update_kv_cache_draft_token_location(scheduled_batch,
540+
attn_metadata,
541+
kv_cache_dtype_byte_size)
542+
# rewind kv cache
543+
for request in scheduled_batch.generation_requests:
544+
if request.state != LlmRequestState.GENERATION_COMPLETE:
545+
if request.py_rewind_len > 0:
546+
self.rewind_kv_cache(request, request.py_rewind_len)
546547

547548
# For context requests, we store the blocks for reuse.
548549
for request in scheduled_batch.context_requests:

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -876,9 +876,12 @@ def _process_draft_tokens_tree(
876876

877877
assert num_accepted_draft_tokens <= longest_accepted_len
878878

879-
request.py_num_accepted_draft_tokens_indices = eagle_paths[longest_match_path_idx][
880-
1:num_accepted_draft_tokens
881-
].tolist() # exclude the root node
879+
# request.py_num_accepted_draft_tokens_indices = eagle_paths[longest_match_path_idx][
880+
# 1:num_accepted_draft_tokens
881+
# ].tolist() # exclude the root node
882+
tree_node_indices = eagle_paths[longest_match_path_idx][1:num_accepted_draft_tokens]
883+
request.py_num_accepted_draft_tokens_indices = (tree_node_indices - 1).tolist()
884+
882885
return num_accepted_draft_tokens - 1
883886

884887
@torch.inference_mode()

0 commit comments

Comments (0)