Skip to content

Commit 18a2a25

Browse files
authored
[KVCache] Support KVCache decode from forked sequence and pop more tokens (#16995)
1 parent 3cd6673 commit 18a2a25

File tree

1 file changed

+53
-12
lines changed

1 file changed

+53
-12
lines changed

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -925,10 +925,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
925925
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
926926
// Fork at last by appending a new block directly
927927
int32_t parent_block_idx = parent_it->second.last_block_idx;
928+
if (!global_block_pool_[parent_block_idx].seq_length) {
929+
// If parent ends with empty block, fork from parent's parent block
930+
parent_block_idx = global_block_pool_[parent_block_idx].parent_idx;
931+
}
928932
++global_block_pool_[parent_block_idx].external_ref_cnt;
929933
// Update child block start position and parent index
930934
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
931935
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
936+
if (global_block_pool_[parent_block_idx].seq_length) {
937+
// If parent is not empty, append a new block
938+
int32_t new_parent_block_idx = GetFreeBlock();
939+
global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
940+
global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
941+
parent_it->second.last_block_idx = new_parent_block_idx;
942+
}
932943
} else {
933944
// Locate the block to fork from and calculate in-block offset
934945
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
@@ -1038,21 +1049,51 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10381049
auto it = seq_map_.find(seq_id);
10391050
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
10401051

1041-
Block& block = global_block_pool_[it->second.last_block_idx];
10421052
CHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative.";
1043-
CHECK_LE(n, block.seq_length) << "The sequence only has length " << block.seq_length
1044-
<< " in the last block, while the length of pop is " << n
1045-
<< " which exceeds the last-block sequence length.";
1053+
CHECK_LE(n, it->second.seq_length)
1054+
<< "The sequence only has length " << it->second.seq_length
1055+
<< ", while the length of pop is " << n << " which exceeds the whole sequence length.";
1056+
int32_t block_idx = it->second.last_block_idx;
1057+
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
1058+
if (n > global_block_pool_[block_idx].seq_length) {
1059+
n -= global_block_pool_[block_idx].seq_length;
1060+
it->second.seq_length -= global_block_pool_[block_idx].seq_length;
1061+
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
1062+
free_page_ids_.push_back(page_id);
1063+
}
1064+
free_block_idx_.push_back(block_idx);
1065+
block_idx = global_block_pool_[block_idx].parent_idx;
1066+
it->second.last_block_idx = block_idx;
1067+
continue;
1068+
}
1069+
if (n <= global_block_pool_[block_idx].seq_length) {
1070+
int64_t cur_npage = global_block_pool_[block_idx].page_ids.size();
1071+
int64_t tgt_npage =
1072+
(global_block_pool_[block_idx].seq_length - n + page_size_ - 1) / page_size_;
1073+
while (cur_npage > tgt_npage) {
1074+
free_page_ids_.push_back(global_block_pool_[block_idx].page_ids.back());
1075+
global_block_pool_[block_idx].page_ids.pop_back();
1076+
--cur_npage;
1077+
}
1078+
it->second.seq_length -= n;
1079+
global_block_pool_[block_idx].seq_length -= n;
1080+
n = 0;
1081+
break;
1082+
}
1083+
}
10461084

1047-
int64_t cur_npage = block.page_ids.size();
1048-
int64_t tgt_npage = (block.seq_length - n + page_size_ - 1) / page_size_;
1049-
while (cur_npage > tgt_npage) {
1050-
free_page_ids_.push_back(block.page_ids.back());
1051-
block.page_ids.pop_back();
1052-
--cur_npage;
1085+
if (n) {
1086+
int32_t temp_seq_id = -1 - seq_id;
1087+
CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
1088+
ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
1089+
CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
1090+
RemoveSequence(seq_id);
1091+
CHECK(seq_map_.find(seq_id) == seq_map_.end());
1092+
auto it = seq_map_.find(temp_seq_id);
1093+
seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
1094+
seq_map_.erase(temp_seq_id);
10531095
}
1054-
it->second.seq_length -= n;
1055-
block.seq_length -= n;
1096+
10561097
dirty_aux_data_device_ = true;
10571098
}
10581099

0 commit comments

Comments
 (0)