Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 53 additions & 12 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -925,10 +925,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
// Fork at last by appending a new block directly
int32_t parent_block_idx = parent_it->second.last_block_idx;
if (!global_block_pool_[parent_block_idx].seq_length) {
// If parent ends with empty block, fork from parent's parent block
parent_block_idx = global_block_pool_[parent_block_idx].parent_idx;
}
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Update child block start position and parent index
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
if (global_block_pool_[parent_block_idx].seq_length) {
// If parent is not empty, append a new block
int32_t new_parent_block_idx = GetFreeBlock();
global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
parent_it->second.last_block_idx = new_parent_block_idx;
}
} else {
// Locate the block to fork from and calculate in-block offset
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
Expand Down Expand Up @@ -1038,21 +1049,51 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";

Block& block = global_block_pool_[it->second.last_block_idx];
CHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative.";
CHECK_LE(n, block.seq_length) << "The sequence only has length " << block.seq_length
<< " in the last block, while the length of pop is " << n
<< " which exceeds the last-block sequence length.";
CHECK_LE(n, it->second.seq_length)
<< "The sequence only has length " << it->second.seq_length
<< ", while the length of pop is " << n << " which exceeds the whole sequence length.";
int32_t block_idx = it->second.last_block_idx;
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
if (n > global_block_pool_[block_idx].seq_length) {
n -= global_block_pool_[block_idx].seq_length;
it->second.seq_length -= global_block_pool_[block_idx].seq_length;
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
free_page_ids_.push_back(page_id);
}
free_block_idx_.push_back(block_idx);
block_idx = global_block_pool_[block_idx].parent_idx;
it->second.last_block_idx = block_idx;
continue;
}
if (n <= global_block_pool_[block_idx].seq_length) {
int64_t cur_npage = global_block_pool_[block_idx].page_ids.size();
int64_t tgt_npage =
(global_block_pool_[block_idx].seq_length - n + page_size_ - 1) / page_size_;
while (cur_npage > tgt_npage) {
free_page_ids_.push_back(global_block_pool_[block_idx].page_ids.back());
global_block_pool_[block_idx].page_ids.pop_back();
--cur_npage;
}
it->second.seq_length -= n;
global_block_pool_[block_idx].seq_length -= n;
n = 0;
break;
}
}

int64_t cur_npage = block.page_ids.size();
int64_t tgt_npage = (block.seq_length - n + page_size_ - 1) / page_size_;
while (cur_npage > tgt_npage) {
free_page_ids_.push_back(block.page_ids.back());
block.page_ids.pop_back();
--cur_npage;
if (n) {
int32_t temp_seq_id = -1 - seq_id;
CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
RemoveSequence(seq_id);
CHECK(seq_map_.find(seq_id) == seq_map_.end());
auto it = seq_map_.find(temp_seq_id);
seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
seq_map_.erase(temp_seq_id);
}
it->second.seq_length -= n;
block.seq_length -= n;

dirty_aux_data_device_ = true;
}

Expand Down