Skip to content

Commit c3be89a

Browse files
authored
[KVCache] Support forking sequence at specific position (#16813)
This PR enables KVCache to fork a sequence at a specific position.
1 parent 5daa303 commit c3be89a

File tree

6 files changed

+283
-56
lines changed

6 files changed

+283
-56
lines changed

src/runtime/relax_vm/kv_state.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,12 @@ class KVStateObj : public Object {
5959
* \param parent_seq_id The parent (source) of the fork.
6060
* \param child_seq_id The child (destination) of the fork.
6161
* The child sequence id should not exist in cache prior to fork.
62+
* \param fork_pos The position in the parent sequence to fork at. Legal values are within
63+
* [0, parent_seq_length], or -1 (the default) for the last position. Forking at position 0
64+
* is equivalent to adding a new sequence with the child sequence id.
6265
* \throws Error if the given sequence ids are not valid.
6366
*/
64-
virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0;
67+
virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) = 0;
6568

6669
/*!
6770
* \brief Pop out the trailing `n` tokens from the KV cache for the

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 102 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
373373
Optional<PackedFunc> f_attention_decode_end_forward_;
374374
PackedFunc f_merge_inplace_;
375375
PackedFunc f_split_rotary_;
376+
PackedFunc f_copy_single_page_;
376377
Optional<PackedFunc> f_debug_get_kv_;
377378

378379
/*! \brief Number of fork depth in the current round of forward. */
@@ -407,7 +408,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
407408
Optional<PackedFunc> f_attention_prefill_end_forward,
408409
Optional<PackedFunc> f_attention_decode_begin_forward,
409410
Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc f_merge_inplace,
410-
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv)
411+
PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_debug_get_kv)
411412
: page_size_(page_size),
412413
num_layers_(num_layers),
413414
num_qo_heads_(num_qo_heads),
@@ -435,6 +436,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
435436
f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
436437
f_merge_inplace_(std::move(f_merge_inplace)),
437438
f_split_rotary_(std::move(f_split_rotary)),
439+
f_copy_single_page_(std::move(f_copy_single_page)),
438440
f_debug_get_kv_(std::move(f_debug_get_kv)),
439441
device_(device) {
440442
pages_.reserve(num_layers);
@@ -527,27 +529,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
527529
void RemoveSequence(int64_t seq_id) final {
528530
auto it = seq_map_.find(seq_id);
529531
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
530-
const Block& block = global_block_pool_[it->second.last_block_idx];
531-
CHECK_EQ(block.external_ref_cnt, 0)
532+
int32_t block_idx = it->second.last_block_idx;
533+
CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
532534
<< "The sequence is currently referenced by other sequence and thus cannot be removed.";
533-
534-
// - Decrease the external reference of the parent block.
535-
if (block.parent_idx != -1) {
536-
Block& parent_block = global_block_pool_[block.parent_idx];
537-
ICHECK_GT(parent_block.external_ref_cnt, 0);
538-
--parent_block.external_ref_cnt;
535+
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
536+
// - Free pages in the last block.
537+
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
538+
free_page_ids_.push_back(page_id);
539+
}
540+
free_block_idx_.push_back(block_idx);
541+
block_idx = global_block_pool_[block_idx].parent_idx;
539542
}
540-
// - Free pages in the last block.
541-
for (int32_t page_id : block.page_ids) {
542-
free_page_ids_.push_back(page_id);
543+
// - Decrease the external reference of the parent block.
544+
if (block_idx != -1) {
545+
ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
546+
--global_block_pool_[block_idx].external_ref_cnt;
543547
}
544-
// - Remove the sequence from seq_map.
545-
free_block_idx_.push_back(it->second.last_block_idx);
546548
seq_map_.erase(it);
547549
dirty_aux_data_device_ = true;
548550
}
549551

550-
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
552+
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
551553
auto parent_it = seq_map_.find(parent_seq_id);
552554
CHECK(parent_it != seq_map_.end())
553555
<< "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
@@ -556,18 +558,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
556558
CHECK_EQ(parent_it->second.sliding_window_size, -1)
557559
<< "The parent sequence \"" << parent_seq_id
558560
<< "\" is enabled with sliding window and thus cannot be forked.";
561+
CHECK_GE(fork_pos, -1)
562+
<< "The forked position should be non-negative, or -1 for last position as default.";
563+
CHECK_LE(fork_pos, parent_it->second.seq_length)
564+
<< "The forked position should not exceed the total length of parent sequence.";
559565

560-
int32_t parent_block_idx = parent_it->second.last_block_idx;
561-
++global_block_pool_[parent_block_idx].external_ref_cnt;
562-
// Create a child block with the parent block pointer.
563566
int32_t child_block_idx = GetFreeBlock();
564-
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
565-
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
567+
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
568+
// Fork at last by appending a new block directly
569+
int32_t parent_block_idx = parent_it->second.last_block_idx;
570+
++global_block_pool_[parent_block_idx].external_ref_cnt;
571+
// Update child block start position and parent index
572+
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
573+
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
574+
} else {
575+
// Locate the block to fork from and calculate in-block offset
576+
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
577+
int64_t in_block_offset = fork_pos;
578+
int32_t forked_block_idx = -1;
579+
for (int32_t block_idx : trace) {
580+
if (in_block_offset < global_block_pool_[block_idx].seq_length) {
581+
forked_block_idx = block_idx;
582+
break;
583+
}
584+
in_block_offset -= global_block_pool_[block_idx].seq_length;
585+
}
586+
int32_t in_page_offset = in_block_offset % page_size_;
587+
int32_t moved_offset = in_block_offset - in_page_offset;
588+
if (moved_offset == 0) {
589+
// Forked at the first page in block
590+
int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx;
591+
if (parent_block_idx != -1) {
592+
++global_block_pool_[parent_block_idx].external_ref_cnt;
593+
}
594+
// Update child block start position and parent index
595+
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
596+
} else {
597+
// Forked at the second or latter page in block
598+
int32_t parent_block_idx = GetFreeBlock();
599+
// Insert new parent block before forked block and link child block
600+
global_block_pool_[parent_block_idx].parent_idx =
601+
global_block_pool_[forked_block_idx].parent_idx;
602+
global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
603+
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
604+
global_block_pool_[parent_block_idx].external_ref_cnt = 1;
605+
606+
// Move common leading pages to new parent block
607+
auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
608+
auto last_page =
609+
global_block_pool_[forked_block_idx].page_ids.begin() + moved_offset / page_size_;
610+
global_block_pool_[parent_block_idx].page_ids = {first_page, last_page};
611+
global_block_pool_[forked_block_idx].page_ids.erase(first_page, last_page);
612+
613+
// Update start position per blocks
614+
global_block_pool_[parent_block_idx].start_pos =
615+
global_block_pool_[forked_block_idx].start_pos;
616+
global_block_pool_[forked_block_idx].start_pos += moved_offset;
617+
618+
// Update in-block sequence length per blocks
619+
global_block_pool_[parent_block_idx].seq_length = moved_offset;
620+
global_block_pool_[forked_block_idx].seq_length -= moved_offset;
621+
}
622+
global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
623+
global_block_pool_[child_block_idx].seq_length = in_page_offset;
624+
625+
if (in_page_offset > 0) {
626+
// Fork within a page and copy common page to child block partially
627+
int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids[0];
628+
int32_t tgt_page_id = GetFreePage();
629+
global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id);
630+
CopySinglePage(src_page_id, tgt_page_id, in_page_offset);
631+
}
632+
}
566633
// Create the child sequence with the child block.
567634
seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
568635
dirty_aux_data_device_ = true;
569636
}
570637

638+
void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t copy_length) {
639+
for (int layer = 0; layer < num_layers_; ++layer) {
640+
f_copy_single_page_(pages_[layer], src_page_id, tgt_page_id, copy_length);
641+
}
642+
}
643+
571644
void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
572645
int32_t attn_sink_size) final {
573646
CHECK(support_sliding_window_) << "The KV cache does not support sliding window.";
@@ -1390,7 +1463,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
13901463
// - Reset the dirty flag to false.
13911464
dirty_aux_data_device_ = false;
13921465
}
1393-
};
1466+
}; // namespace relax_vm
13941467

13951468
TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
13961469

@@ -1412,7 +1485,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
14121485
PackedFunc f_attention_prefill_end_forward,
14131486
PackedFunc f_attention_decode_begin_forward,
14141487
PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace,
1415-
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
1488+
PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
1489+
Optional<PackedFunc> f_debug_get_kv) {
14161490
CHECK_EQ(cache_config.size(), 5);
14171491
int64_t reserved_num_seqs = cache_config[0];
14181492
int64_t total_token_capacity = cache_config[1];
@@ -1435,7 +1509,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
14351509
std::move(f_attention_prefill_ragged_end_forward),
14361510
std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
14371511
std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
1438-
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
1512+
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
1513+
std::move(f_debug_get_kv));
14391514
return AttentionKVCache(std::move(n));
14401515
});
14411516

@@ -1447,7 +1522,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
14471522
PackedFunc f_attention_prefill_sliding_window,
14481523
PackedFunc f_attention_decode_sliding_window,
14491524
PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace,
1450-
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
1525+
PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
1526+
Optional<PackedFunc> f_debug_get_kv) {
14511527
CHECK_EQ(cache_config.size(), 5);
14521528
int64_t reserved_num_seqs = cache_config[0];
14531529
int64_t total_token_capacity = cache_config[1];
@@ -1467,7 +1543,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
14671543
std::move(f_attention_prefill_sliding_window),
14681544
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), //
14691545
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
1470-
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
1546+
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
1547+
std::move(f_debug_get_kv));
14711548
return AttentionKVCache(std::move(n));
14721549
});
14731550

src/runtime/relax_vm/rnn_state.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ class RNNStateImpObj : public RNNStateObj {
319319
dirty_aux_data_device_ = true;
320320
}
321321

322-
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
322+
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
323323
auto parent_it = seq_map_.find(parent_seq_id);
324324
CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id
325325
<< "\" cannot be found in space state storage.";

0 commit comments

Comments
 (0)