@@ -373,6 +373,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
373373 Optional<PackedFunc> f_attention_decode_end_forward_;
374374 PackedFunc f_merge_inplace_;
375375 PackedFunc f_split_rotary_;
376+ PackedFunc f_copy_single_page_;
376377 Optional<PackedFunc> f_debug_get_kv_;
377378
378379 /* ! \brief Number of fork depth in the current round of forward. */
@@ -407,7 +408,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
407408 Optional<PackedFunc> f_attention_prefill_end_forward,
408409 Optional<PackedFunc> f_attention_decode_begin_forward,
409410 Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc f_merge_inplace,
410- PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv)
411+ PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_debug_get_kv)
411412 : page_size_(page_size),
412413 num_layers_(num_layers),
413414 num_qo_heads_(num_qo_heads),
@@ -435,6 +436,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
435436 f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
436437 f_merge_inplace_(std::move(f_merge_inplace)),
437438 f_split_rotary_(std::move(f_split_rotary)),
439+ f_copy_single_page_(std::move(f_copy_single_page)),
438440 f_debug_get_kv_(std::move(f_debug_get_kv)),
439441 device_(device) {
440442 pages_.reserve (num_layers);
@@ -527,27 +529,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
527529 void RemoveSequence (int64_t seq_id) final {
528530 auto it = seq_map_.find (seq_id);
529531 CHECK (it != seq_map_.end ()) << " The sequence \" " << seq_id << " \" cannot be found in KV cache." ;
530- const Block& block = global_block_pool_[ it->second .last_block_idx ] ;
531- CHECK_EQ (block .external_ref_cnt , 0 )
532+ int32_t block_idx = it->second .last_block_idx ;
533+ CHECK_EQ (global_block_pool_[block_idx] .external_ref_cnt , 0 )
532534 << " The sequence is currently referenced by other sequence and thus cannot be removed." ;
533-
534- // - Decrease the external reference of the parent block.
535- if (block.parent_idx != -1 ) {
536- Block& parent_block = global_block_pool_[block.parent_idx ];
537- ICHECK_GT (parent_block.external_ref_cnt , 0 );
538- --parent_block.external_ref_cnt ;
535+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0 ) {
536+ // - Free pages in the last block.
537+ for (int32_t page_id : global_block_pool_[block_idx].page_ids ) {
538+ free_page_ids_.push_back (page_id);
539+ }
540+ free_block_idx_.push_back (block_idx);
541+ block_idx = global_block_pool_[block_idx].parent_idx ;
539542 }
540- // - Free pages in the last block.
541- for (int32_t page_id : block.page_ids ) {
542- free_page_ids_.push_back (page_id);
543+ // - Decrease the external reference of the parent block.
544+ if (block_idx != -1 ) {
545+ ICHECK_GT (global_block_pool_[block_idx].external_ref_cnt , 0 );
546+ --global_block_pool_[block_idx].external_ref_cnt ;
543547 }
544- // - Remove the sequence from seq_map.
545- free_block_idx_.push_back (it->second .last_block_idx );
546548 seq_map_.erase (it);
547549 dirty_aux_data_device_ = true ;
548550 }
549551
550- void ForkSequence (int64_t parent_seq_id, int64_t child_seq_id) final {
552+ void ForkSequence (int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = - 1 ) final {
551553 auto parent_it = seq_map_.find (parent_seq_id);
552554 CHECK (parent_it != seq_map_.end ())
553555 << " The parent sequence \" " << parent_seq_id << " \" cannot be found in KV cache." ;
@@ -556,18 +558,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
556558 CHECK_EQ (parent_it->second .sliding_window_size , -1 )
557559 << " The parent sequence \" " << parent_seq_id
558560 << " \" is enabled with sliding window and thus cannot be forked." ;
561+ CHECK_GE (fork_pos, -1 )
562+ << " The forked position should be non-negative, or -1 for last position as default." ;
563+ CHECK_LE (fork_pos, parent_it->second .seq_length )
564+ << " The forked position should not exceed the total length of parent sequence." ;
559565
560- int32_t parent_block_idx = parent_it->second .last_block_idx ;
561- ++global_block_pool_[parent_block_idx].external_ref_cnt ;
562- // Create a child block with the parent block pointer.
563566 int32_t child_block_idx = GetFreeBlock ();
564- global_block_pool_[child_block_idx].start_pos = parent_it->second .seq_length ;
565- global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
567+ if (fork_pos == -1 || fork_pos == parent_it->second .seq_length ) {
568+ // Fork at last by appending a new block directly
569+ int32_t parent_block_idx = parent_it->second .last_block_idx ;
570+ ++global_block_pool_[parent_block_idx].external_ref_cnt ;
571+ // Update child block start position and parent index
572+ global_block_pool_[child_block_idx].start_pos = parent_it->second .seq_length ;
573+ global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
574+ } else {
575+ // Locate the block to fork from and calculate in-block offset
576+ std::vector<int32_t > trace = parent_it->second .GetBlockTrace (global_block_pool_);
577+ int64_t in_block_offset = fork_pos;
578+ int32_t forked_block_idx = -1 ;
579+ for (int32_t block_idx : trace) {
580+ if (in_block_offset < global_block_pool_[block_idx].seq_length ) {
581+ forked_block_idx = block_idx;
582+ break ;
583+ }
584+ in_block_offset -= global_block_pool_[block_idx].seq_length ;
585+ }
586+ int32_t in_page_offset = in_block_offset % page_size_;
587+ int32_t moved_offset = in_block_offset - in_page_offset;
588+ if (moved_offset == 0 ) {
589+ // Forked at the first page in block
590+ int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx ;
591+ if (parent_block_idx != -1 ) {
592+ ++global_block_pool_[parent_block_idx].external_ref_cnt ;
593+ }
594+ // Update child block start position and parent index
595+ global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
596+ } else {
597+ // Forked at the second or later page in block
598+ int32_t parent_block_idx = GetFreeBlock ();
599+ // Insert new parent block before forked block and link child block
600+ global_block_pool_[parent_block_idx].parent_idx =
601+ global_block_pool_[forked_block_idx].parent_idx ;
602+ global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
603+ global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
604+ global_block_pool_[parent_block_idx].external_ref_cnt = 1 ;
605+
606+ // Move common leading pages to new parent block
607+ auto first_page = global_block_pool_[forked_block_idx].page_ids .begin ();
608+ auto last_page =
609+ global_block_pool_[forked_block_idx].page_ids .begin () + moved_offset / page_size_;
610+ global_block_pool_[parent_block_idx].page_ids = {first_page, last_page};
611+ global_block_pool_[forked_block_idx].page_ids .erase (first_page, last_page);
612+
613+ // Update start position per blocks
614+ global_block_pool_[parent_block_idx].start_pos =
615+ global_block_pool_[forked_block_idx].start_pos ;
616+ global_block_pool_[forked_block_idx].start_pos += moved_offset;
617+
618+ // Update in-block sequence length per blocks
619+ global_block_pool_[parent_block_idx].seq_length = moved_offset;
620+ global_block_pool_[forked_block_idx].seq_length -= moved_offset;
621+ }
622+ global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
623+ global_block_pool_[child_block_idx].seq_length = in_page_offset;
624+
625+ if (in_page_offset > 0 ) {
626+ // Fork within a page and copy common page to child block partially
627+ int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids [0 ];
628+ int32_t tgt_page_id = GetFreePage ();
629+ global_block_pool_[child_block_idx].page_ids .push_back (tgt_page_id);
630+ CopySinglePage (src_page_id, tgt_page_id, in_page_offset);
631+ }
632+ }
566633 // Create the child sequence with the child block.
567634 seq_map_.insert ({child_seq_id, Sequence (global_block_pool_, child_block_idx)});
568635 dirty_aux_data_device_ = true ;
569636 }
570637
638+ void CopySinglePage (int32_t src_page_id, int32_t tgt_page_id, int64_t copy_length) {
639+ for (int layer = 0 ; layer < num_layers_; ++layer) {
640+ f_copy_single_page_ (pages_[layer], src_page_id, tgt_page_id, copy_length);
641+ }
642+ }
643+
571644 void EnableSlidingWindowForSeq (int64_t seq_id, int32_t sliding_window_size,
572645 int32_t attn_sink_size) final {
573646 CHECK (support_sliding_window_) << " The KV cache does not support sliding window." ;
@@ -1390,7 +1463,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
13901463 // - Reset the dirty flag to false.
13911464 dirty_aux_data_device_ = false ;
13921465 }
1393- };
1466+ }; // class PagedAttentionKVCacheObj
13941467
13951468TVM_REGISTER_OBJECT_TYPE (PagedAttentionKVCacheObj);
13961469
@@ -1412,7 +1485,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
14121485 PackedFunc f_attention_prefill_end_forward,
14131486 PackedFunc f_attention_decode_begin_forward,
14141487 PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace,
1415- PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
1488+ PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
1489+ Optional<PackedFunc> f_debug_get_kv) {
14161490 CHECK_EQ (cache_config.size (), 5 );
14171491 int64_t reserved_num_seqs = cache_config[0 ];
14181492 int64_t total_token_capacity = cache_config[1 ];
@@ -1435,7 +1509,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
14351509 std::move (f_attention_prefill_ragged_end_forward),
14361510 std::move (f_attention_prefill_begin_forward), std::move (f_attention_prefill_end_forward),
14371511 std::move (f_attention_decode_begin_forward), std::move (f_attention_decode_end_forward),
1438- std::move (f_merge_inplace), std::move (f_split_rotary), std::move (f_debug_get_kv));
1512+ std::move (f_merge_inplace), std::move (f_split_rotary), std::move (f_copy_single_page),
1513+ std::move (f_debug_get_kv));
14391514 return AttentionKVCache (std::move (n));
14401515 });
14411516
@@ -1447,7 +1522,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
14471522 PackedFunc f_attention_prefill_sliding_window,
14481523 PackedFunc f_attention_decode_sliding_window,
14491524 PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace,
1450- PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
1525+ PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
1526+ Optional<PackedFunc> f_debug_get_kv) {
14511527 CHECK_EQ (cache_config.size (), 5 );
14521528 int64_t reserved_num_seqs = cache_config[0 ];
14531529 int64_t total_token_capacity = cache_config[1 ];
@@ -1467,7 +1543,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
14671543 std::move (f_attention_prefill_sliding_window),
14681544 std::move (f_attention_decode_sliding_window), std::move (f_attention_prefill_ragged), //
14691545 NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
1470- std::move (f_merge_inplace), std::move (f_split_rotary), std::move (f_debug_get_kv));
1546+ std::move (f_merge_inplace), std::move (f_split_rotary), std::move (f_copy_single_page),
1547+ std::move (f_debug_get_kv));
14711548 return AttentionKVCache (std::move (n));
14721549 });
14731550
0 commit comments