@@ -925,10 +925,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
925925 if (fork_pos == -1 || fork_pos == parent_it->second .seq_length ) {
926926 // Fork at last by appending a new block directly
927927 int32_t parent_block_idx = parent_it->second .last_block_idx ;
928+ if (!global_block_pool_[parent_block_idx].seq_length ) {
929+ // If parent ends with empty block, fork from parent's parent block
930+ parent_block_idx = global_block_pool_[parent_block_idx].parent_idx ;
931+ }
928932 ++global_block_pool_[parent_block_idx].external_ref_cnt ;
929933 // Update child block start position and parent index
930934 global_block_pool_[child_block_idx].start_pos = parent_it->second .seq_length ;
931935 global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
936+ if (global_block_pool_[parent_block_idx].seq_length ) {
937+ // If parent is not empty, append a new block
938+ int32_t new_parent_block_idx = GetFreeBlock ();
939+ global_block_pool_[new_parent_block_idx].start_pos = parent_it->second .seq_length ;
940+ global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
941+ parent_it->second .last_block_idx = new_parent_block_idx;
942+ }
932943 } else {
933944 // Locate the block to fork from and calculate in-block offset
934945 std::vector<int32_t > trace = parent_it->second .GetBlockTrace (global_block_pool_);
@@ -1038,21 +1049,51 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10381049 auto it = seq_map_.find (seq_id);
10391050 CHECK (it != seq_map_.end ()) << " The sequence \" " << seq_id << " \" cannot be found in KV cache." ;
10401051
1041- Block& block = global_block_pool_[it->second .last_block_idx ];
10421052 CHECK_GE (n, 0 ) << " The length of popping " << n << " cannot be negative." ;
1043- CHECK_LE (n, block.seq_length ) << " The sequence only has length " << block.seq_length
1044- << " in the last block, while the length of pop is " << n
1045- << " which exceeds the last-block sequence length." ;
1053+ CHECK_LE (n, it->second .seq_length )
1054+ << " The sequence only has length " << it->second .seq_length
1055+ << " , while the length of pop is " << n << " which exceeds the whole sequence length." ;
1056+ int32_t block_idx = it->second .last_block_idx ;
1057+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0 ) {
1058+ if (n > global_block_pool_[block_idx].seq_length ) {
1059+ n -= global_block_pool_[block_idx].seq_length ;
1060+ it->second .seq_length -= global_block_pool_[block_idx].seq_length ;
1061+ for (int32_t page_id : global_block_pool_[block_idx].page_ids ) {
1062+ free_page_ids_.push_back (page_id);
1063+ }
1064+ free_block_idx_.push_back (block_idx);
1065+ block_idx = global_block_pool_[block_idx].parent_idx ;
1066+ it->second .last_block_idx = block_idx;
1067+ continue ;
1068+ }
1069+ if (n <= global_block_pool_[block_idx].seq_length ) {
1070+ int64_t cur_npage = global_block_pool_[block_idx].page_ids .size ();
1071+ int64_t tgt_npage =
1072+ (global_block_pool_[block_idx].seq_length - n + page_size_ - 1 ) / page_size_;
1073+ while (cur_npage > tgt_npage) {
1074+ free_page_ids_.push_back (global_block_pool_[block_idx].page_ids .back ());
1075+ global_block_pool_[block_idx].page_ids .pop_back ();
1076+ --cur_npage;
1077+ }
1078+ it->second .seq_length -= n;
1079+ global_block_pool_[block_idx].seq_length -= n;
1080+ n = 0 ;
1081+ break ;
1082+ }
1083+ }
10461084
1047- int64_t cur_npage = block.page_ids .size ();
1048- int64_t tgt_npage = (block.seq_length - n + page_size_ - 1 ) / page_size_;
1049- while (cur_npage > tgt_npage) {
1050- free_page_ids_.push_back (block.page_ids .back ());
1051- block.page_ids .pop_back ();
1052- --cur_npage;
1085+ if (n) {
1086+ int32_t temp_seq_id = -1 - seq_id;
1087+ CHECK (seq_map_.find (temp_seq_id) == seq_map_.end ());
1088+ ForkSequence (seq_id, temp_seq_id, it->second .seq_length - n);
1089+ CHECK (seq_map_.find (temp_seq_id) != seq_map_.end ());
1090+ RemoveSequence (seq_id);
1091+ CHECK (seq_map_.find (seq_id) == seq_map_.end ());
1092+ auto it = seq_map_.find (temp_seq_id);
1093+ seq_map_.insert ({seq_id, Sequence (global_block_pool_, it->second .last_block_idx )});
1094+ seq_map_.erase (temp_seq_id);
10531095 }
1054- it->second .seq_length -= n;
1055- block.seq_length -= n;
1096+
10561097 dirty_aux_data_device_ = true ;
10571098 }
10581099
0 commit comments