Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ class KVStateObj : public Object {
* \param parent_seq_id The parent (source) of the fork.
* \param child_seq_id The child (destination) of the fork.
* The child sequence id should not exist in cache prior to fork.
* \param fork_pos The position in the parent sequence at which to fork. Legal values are
* within [0, parent_seq_length], or -1 (the default) to fork at the last position. Forking
* at position 0 is equivalent to adding a new sequence with the child sequence id.
* \throws Error if the given sequence ids are not valid.
*/
virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) = 0;
virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) = 0;

/*!
* \brief Pop out the trailing `n` tokens from the KV cache for the
Expand Down
127 changes: 102 additions & 25 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
Optional<PackedFunc> f_attention_decode_end_forward_;
PackedFunc f_merge_inplace_;
PackedFunc f_split_rotary_;
PackedFunc f_copy_single_page_;
Optional<PackedFunc> f_debug_get_kv_;

/*! \brief Number of fork depth in the current round of forward. */
Expand Down Expand Up @@ -407,7 +408,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
Optional<PackedFunc> f_attention_prefill_end_forward,
Optional<PackedFunc> f_attention_decode_begin_forward,
Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc f_merge_inplace,
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv)
PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
num_qo_heads_(num_qo_heads),
Expand Down Expand Up @@ -435,6 +436,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
f_merge_inplace_(std::move(f_merge_inplace)),
f_split_rotary_(std::move(f_split_rotary)),
f_copy_single_page_(std::move(f_copy_single_page)),
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
pages_.reserve(num_layers);
Expand Down Expand Up @@ -527,27 +529,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
void RemoveSequence(int64_t seq_id) final {
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
const Block& block = global_block_pool_[it->second.last_block_idx];
CHECK_EQ(block.external_ref_cnt, 0)
int32_t block_idx = it->second.last_block_idx;
CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
<< "The sequence is currently referenced by other sequence and thus cannot be removed.";

// - Decrease the external reference of the parent block.
if (block.parent_idx != -1) {
Block& parent_block = global_block_pool_[block.parent_idx];
ICHECK_GT(parent_block.external_ref_cnt, 0);
--parent_block.external_ref_cnt;
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
// - Free pages in the last block.
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
free_page_ids_.push_back(page_id);
}
free_block_idx_.push_back(block_idx);
block_idx = global_block_pool_[block_idx].parent_idx;
}
// - Free pages in the last block.
for (int32_t page_id : block.page_ids) {
free_page_ids_.push_back(page_id);
// - Decrease the external reference of the parent block.
if (block_idx != -1) {
ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
--global_block_pool_[block_idx].external_ref_cnt;
}
// - Remove the sequence from seq_map.
free_block_idx_.push_back(it->second.last_block_idx);
seq_map_.erase(it);
dirty_aux_data_device_ = true;
}

void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
auto parent_it = seq_map_.find(parent_seq_id);
CHECK(parent_it != seq_map_.end())
<< "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
Expand All @@ -556,18 +558,89 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
CHECK_EQ(parent_it->second.sliding_window_size, -1)
<< "The parent sequence \"" << parent_seq_id
<< "\" is enabled with sliding window and thus cannot be forked.";
CHECK_GE(fork_pos, -1)
<< "The forked position should be non-negative, or -1 for last position as default.";
CHECK_LE(fork_pos, parent_it->second.seq_length)
<< "The forked position should not exceed the total length of parent sequence.";

int32_t parent_block_idx = parent_it->second.last_block_idx;
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Create a child block with the parent block pointer.
int32_t child_block_idx = GetFreeBlock();
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
// Fork at last by appending a new block directly
int32_t parent_block_idx = parent_it->second.last_block_idx;
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Update child block start position and parent index
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
} else {
// Locate the block to fork from and calculate in-block offset
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
int64_t in_block_offset = fork_pos;
int32_t forked_block_idx = -1;
for (int32_t block_idx : trace) {
if (in_block_offset < global_block_pool_[block_idx].seq_length) {
forked_block_idx = block_idx;
break;
}
in_block_offset -= global_block_pool_[block_idx].seq_length;
}
int32_t in_page_offset = in_block_offset % page_size_;
int32_t moved_offset = in_block_offset - in_page_offset;
if (moved_offset == 0) {
// Forked at the first page in block
int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx;
if (parent_block_idx != -1) {
++global_block_pool_[parent_block_idx].external_ref_cnt;
}
// Update child block start position and parent index
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
} else {
// Forked at the second or latter page in block
int32_t parent_block_idx = GetFreeBlock();
// Insert new parent block before forked block and link child block
global_block_pool_[parent_block_idx].parent_idx =
global_block_pool_[forked_block_idx].parent_idx;
global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
global_block_pool_[parent_block_idx].external_ref_cnt = 1;

// Move common leading pages to new parent block
auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
auto last_page =
global_block_pool_[forked_block_idx].page_ids.begin() + moved_offset / page_size_;
global_block_pool_[parent_block_idx].page_ids = {first_page, last_page};
global_block_pool_[forked_block_idx].page_ids.erase(first_page, last_page);

// Update start position per blocks
global_block_pool_[parent_block_idx].start_pos =
global_block_pool_[forked_block_idx].start_pos;
global_block_pool_[forked_block_idx].start_pos += moved_offset;

// Update in-block sequence length per blocks
global_block_pool_[parent_block_idx].seq_length = moved_offset;
global_block_pool_[forked_block_idx].seq_length -= moved_offset;
}
global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
global_block_pool_[child_block_idx].seq_length = in_page_offset;

if (in_page_offset > 0) {
// Fork within a page and copy common page to child block partially
int32_t src_page_id = global_block_pool_[forked_block_idx].page_ids[0];
int32_t tgt_page_id = GetFreePage();
global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id);
CopySinglePage(src_page_id, tgt_page_id, in_page_offset);
}
}
// Create the child sequence with the child block.
seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
dirty_aux_data_device_ = true;
}

/*!
 * \brief Copy from one page to another page in every layer of the cache.
 * The copy itself is delegated to `f_copy_single_page_`, which is invoked once
 * per layer with that layer's entry of `pages_`.
 * \param src_page_id The id of the page to copy from.
 * \param tgt_page_id The id of the page to copy into.
 * \param copy_length The length forwarded unchanged to the copy function;
 * presumably the number of leading positions in the page to copy (TODO confirm
 * against the copy function's implementation).
 */
void CopySinglePage(int32_t src_page_id, int32_t tgt_page_id, int64_t copy_length) {
  // Every layer is visited so the same page pair is copied in each layer's storage.
  for (int layer_id = 0; layer_id < num_layers_; ++layer_id) {
    f_copy_single_page_(pages_[layer_id], src_page_id, tgt_page_id, copy_length);
  }
}

void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
int32_t attn_sink_size) final {
CHECK(support_sliding_window_) << "The KV cache does not support sliding window.";
Expand Down Expand Up @@ -1390,7 +1463,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// - Reset the dirty flag to false.
dirty_aux_data_device_ = false;
}
};
};  // class PagedAttentionKVCacheObj

TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);

Expand All @@ -1412,7 +1485,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
PackedFunc f_attention_prefill_end_forward,
PackedFunc f_attention_decode_begin_forward,
PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace,
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
Optional<PackedFunc> f_debug_get_kv) {
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
Expand All @@ -1435,7 +1509,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
std::move(f_attention_prefill_ragged_end_forward),
std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
std::move(f_debug_get_kv));
return AttentionKVCache(std::move(n));
});

Expand All @@ -1447,7 +1522,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
PackedFunc f_attention_prefill_sliding_window,
PackedFunc f_attention_decode_sliding_window,
PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace,
PackedFunc f_split_rotary, Optional<PackedFunc> f_debug_get_kv) {
PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
Optional<PackedFunc> f_debug_get_kv) {
CHECK_EQ(cache_config.size(), 5);
int64_t reserved_num_seqs = cache_config[0];
int64_t total_token_capacity = cache_config[1];
Expand All @@ -1467,7 +1543,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged), //
NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, //
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_debug_get_kv));
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
std::move(f_debug_get_kv));
return AttentionKVCache(std::move(n));
});

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/relax_vm/rnn_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class RNNStateImpObj : public RNNStateObj {
dirty_aux_data_device_ = true;
}

void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id) final {
void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) final {
auto parent_it = seq_map_.find(parent_seq_id);
CHECK(parent_it != seq_map_.end()) << "The parent sequence \"" << parent_seq_id
<< "\" cannot be found in space state storage.";
Expand Down
Loading