
Commit 3f8a04d

upd

1 parent e85a14f

File tree

2 files changed (+142, -30)

src/runtime/relax_vm/paged_kv_cache.cc (82 additions, 29 deletions)
@@ -178,10 +178,6 @@ struct Sequence {
       }
       block_ptr = block.parent_idx;
     }
-    CHECK_LE(depth, kPagedKVCacheMaxBlockDepth)
-        << "Paged KV cache supports one sequence to reuse " << kPagedKVCacheMaxBlockDepth
-        << " prefixes (the fork depth) at most. However, the given sequence has fork depth "
-        << depth;
   }

   std::vector<int32_t> GetBlockTrace(const std::vector<Block>& global_block_pool) const {
@@ -1490,19 +1486,29 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       is_chain_ = true;
     }

-    std::vector<std::vector<int32_t>> block_ids_on_depths = GetBlockIdsOnDepth(sequences);
-    num_depths_ = block_ids_on_depths.size();
+    auto [block_ids_on_depths, trailing_blocks] = GetBlockIdsOnDepth(sequences);
+    num_depths_ =
+        std::min(static_cast<int>(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth);
     ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);

     std::vector<std::vector<std::pair<int32_t, int32_t>>> chunked_block_ids_arr;
     chunked_block_ids_arr.reserve(num_depths_);
     use_decode_kernel_.clear();
     for (int d = 0; d < num_depths_; ++d) {
-      auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(block_ids_on_depths[d]);
+      // We force the blocks at the maximum depth not to coalesce, so that they can be
+      // concatenated with the trailing exceeding blocks.
+      auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(
+          block_ids_on_depths[d], /*enable_coalesce=*/d != kPagedKVCacheMaxBlockDepth - 1);
       chunked_block_ids_arr.push_back(chunked_block_ids);
       use_decode_kernel_.push_back(use_decode_kernel);
     }

+    if (num_depths_ == kPagedKVCacheMaxBlockDepth) {
+      // Since we force the blocks at the maximum depth not to coalesce, the output blocks at
+      // the maximum depth must have the same size as the current batch.
+      CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
+    }
+
     append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
     if (append_before_attn_) {
       // Right now we use different kernels when depth is 1 or not 1.
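
Note: the blocks at the maximum depth are deliberately left uncoalesced so that every sequence keeps its own entry at that depth, and the sequence's trailing blocks beyond kPagedKVCacheMaxBlockDepth can later be appended to exactly that entry. This is also what lets the CHECK_EQ above assert that the number of max-depth entries equals the batch size.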
@@ -1530,7 +1536,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       k_rope_pos_offset_h.clear();
       qo_indptr_h.push_back(0);
       page_indptr_h.push_back(0);
-      for (const auto& [block_id, chunk_append_length] : chunked_block_ids_arr[d]) {
+      for (int i = 0; i < static_cast<int>(chunked_block_ids_arr[d].size()); ++i) {
+        const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i];
         qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
         if (block_id == -1) {
           page_indptr_h.push_back(page_indptr_h.back());
@@ -1539,19 +1546,53 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           sink_size_h.push_back(0);
           k_rope_pos_offset_h.push_back(0);
         } else {
-          const Block& block = global_block_pool_[block_id];
-          page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
-          for (int32_t page_id : block.page_ids) {
-            page_indices_h.push_back(page_id);
+          if (d < kPagedKVCacheMaxBlockDepth - 1) {
+            // Blocks not at maximum depth
+            const Block& block = global_block_pool_[block_id];
+            page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
+            for (int32_t page_id : block.page_ids) {
+              page_indices_h.push_back(page_id);
+            }
+            last_page_len_h.push_back(
+                block.seq_length == 0
+                    ? 0
+                    : (block.seq_length - block.sink_length + block.sliding_window_offset - 1) %
+                              page_size_ +
+                          1);
+            sliding_window_offset_h.push_back(block.sliding_window_offset);
+            sink_size_h.push_back(block.sink_length);
+            k_rope_pos_offset_h.push_back(block.start_pos);
+          } else {
+            // Blocks at maximum depth
+            const Block& block = global_block_pool_[block_id];
+            int32_t num_pages = static_cast<int32_t>(block.page_ids.size());
+            int32_t total_seq_length = static_cast<int32_t>(block.seq_length);
+            int32_t last_block_id = block_id;
+            for (int32_t page_id : block.page_ids) {
+              page_indices_h.push_back(page_id);
+            }
+            for (int32_t id : trailing_blocks[i]) {
+              // Collect trailing blocks if available
+              const Block& block = global_block_pool_[id];
+              for (int32_t page_id : block.page_ids) {
+                page_indices_h.push_back(page_id);
+              }
+              num_pages += block.page_ids.size();
+              total_seq_length += block.seq_length;
+              last_block_id = id;
+            }
+            page_indptr_h.push_back(page_indptr_h.back() + num_pages);
+            const Block& last_block = global_block_pool_[last_block_id];
+            last_page_len_h.push_back(total_seq_length == 0
+                                          ? 0
+                                          : (total_seq_length - last_block.sink_length +
+                                             last_block.sliding_window_offset - 1) %
+                                                    page_size_ +
+                                                1);
+            sliding_window_offset_h.push_back(last_block.sliding_window_offset);
+            sink_size_h.push_back(last_block.sink_length);
+            k_rope_pos_offset_h.push_back(block.start_pos);
           }
-          last_page_len_h.push_back(block.seq_length == 0 ? 0
-                                                          : (block.seq_length - block.sink_length +
-                                                             block.sliding_window_offset - 1) %
-                                                                    page_size_ +
-                                                                1);
-          sliding_window_offset_h.push_back(block.sliding_window_offset);
-          sink_size_h.push_back(block.sink_length);
-          k_rope_pos_offset_h.push_back(block.start_pos);
         }
       }
     }
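
Note: for a max-depth entry, the last-page length is computed over the merged span, i.e. the depth-limited block plus all of its trailing blocks: for a nonzero total_seq_length it is (total_seq_length - sink_length + sliding_window_offset - 1) % page_size + 1. For example, with page_size 16, a merged length of 30, and no sink tokens or sliding-window offset, the last page holds (30 - 1) % 16 + 1 = 14 entries.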
@@ -2035,22 +2076,34 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   /*!
    * \brief For the given list of sequences, check the block trace of
    * each sequence, and return the block ids used by the sequences
-   * on each depth.
+   * on each depth. If the depth is larger than kPagedKVCacheMaxBlockDepth,
+   * the exceeding blocks are concatenated and returned separately.
    * More precisely, the inner returned vector contains the block ids
    * used by the sequences on a certain depth (or "-1" if a sequence
    * has fewer depths). The outer returned vector contains the inner
    * vectors from the lowest depth to the highest depth.
    */
-  std::vector<std::vector<int32_t>> GetBlockIdsOnDepth(
-      const std::vector<Sequence*>& sequences) const {
+  std::pair<std::vector<std::vector<int32_t>>, std::vector<std::vector<int32_t>>>
+  GetBlockIdsOnDepth(const std::vector<Sequence*>& sequences) const {
     // - Get the trace of each sequence.
     int64_t num_depths = 0;
     std::vector<std::vector<int32_t>> seq_block_traces;
+    std::vector<std::vector<int32_t>> trailing_block_traces;
     seq_block_traces.reserve(cur_batch_size_);
+    trailing_block_traces.reserve(cur_batch_size_);
     for (int i = 0; i < cur_batch_size_; ++i) {
       std::vector<int32_t> trace = sequences[i]->GetBlockTrace(global_block_pool_);
-      num_depths = std::max(num_depths, static_cast<int64_t>(trace.size()));
-      seq_block_traces.push_back(std::move(trace));
+      if (static_cast<int>(trace.size()) <= kPagedKVCacheMaxBlockDepth) {
+        seq_block_traces.push_back(std::vector<int32_t>(trace.begin(), trace.end()));
+        trailing_block_traces.push_back({});
+        num_depths = std::max(num_depths, static_cast<int64_t>(trace.size()));
+      } else {
+        seq_block_traces.push_back(
+            std::vector<int32_t>(trace.begin(), trace.begin() + kPagedKVCacheMaxBlockDepth));
+        trailing_block_traces.push_back(
+            std::vector<int32_t>(trace.begin() + kPagedKVCacheMaxBlockDepth, trace.end()));
+        num_depths = std::max(num_depths, static_cast<int64_t>(kPagedKVCacheMaxBlockDepth));
+      }
     }

     // "Transpose" the traces, yielding the block ids used on each depth.
@@ -2065,7 +2118,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
       block_ids_on_depths.push_back(std::move(block_ids));
     }
-    return block_ids_on_depths;
+    return {block_ids_on_depths, trailing_block_traces};
   }

   /*!
@@ -2081,7 +2134,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
    * input blocks.
    */
   std::pair<std::vector<std::pair<int32_t, int32_t>>, bool> GetChunkedBlockIds(
-      const std::vector<int32_t>& block_ids) const {
+      const std::vector<int32_t>& block_ids, bool enable_coalesce = true) const {
     std::vector<std::pair<int32_t, int32_t>> uncoalesced_block_ids;
     std::vector<std::pair<int32_t, int32_t>> coalesced_block_ids;

@@ -2115,8 +2168,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     double coalesce_ratio = 1.0 * page_counter_uncoalesced / page_counter_coalesced;
     // Do not coalesce and use batch decode kernel when coalesce ratio is small.
     bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 1.1;
-
-    return {use_decode_kernel ? uncoalesced_block_ids : coalesced_block_ids, use_decode_kernel};
+    return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids,
+            use_decode_kernel};
   }

   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
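
Taken together, the change in this file lifts the previous hard limit on fork depth: instead of tripping the removed CHECK_LE in Sequence, GetBlockIdsOnDepth now caps the per-depth bookkeeping at kPagedKVCacheMaxBlockDepth and returns the exceeding blocks as per-sequence trailing traces, which the attention setup concatenates onto each sequence's max-depth block. Below is a minimal standalone sketch of just the splitting step, with simplified names (kMaxBlockDepth stands in for kPagedKVCacheMaxBlockDepth; an illustration, not the actual class code):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

constexpr int kMaxBlockDepth = 5;  // stand-in for kPagedKVCacheMaxBlockDepth

// Split each sequence's block trace into a prefix of at most kMaxBlockDepth
// blocks (processed depth by depth) and the trailing blocks beyond that
// depth (later concatenated onto the sequence's max-depth block).
std::pair<std::vector<std::vector<int32_t>>, std::vector<std::vector<int32_t>>>
SplitTraces(const std::vector<std::vector<int32_t>>& traces) {
  std::vector<std::vector<int32_t>> prefixes;
  std::vector<std::vector<int32_t>> trailing;
  prefixes.reserve(traces.size());
  trailing.reserve(traces.size());
  for (const std::vector<int32_t>& trace : traces) {
    if (static_cast<int>(trace.size()) <= kMaxBlockDepth) {
      prefixes.emplace_back(trace.begin(), trace.end());
      trailing.emplace_back();
    } else {
      prefixes.emplace_back(trace.begin(), trace.begin() + kMaxBlockDepth);
      trailing.emplace_back(trace.begin() + kMaxBlockDepth, trace.end());
    }
  }
  return {prefixes, trailing};
}

int main() {
  // One sequence forked eight levels deep, as in the new test below.
  auto [prefixes, trailing] = SplitTraces({{0, 1, 2, 3, 5, 7, 9, 11}});
  std::cout << "prefix:";
  for (int32_t id : prefixes[0]) std::cout << ' ' << id;
  std::cout << " | trailing:";
  for (int32_t id : trailing[0]) std::cout << ' ' << id;
  std::cout << '\n';  // prefix: 0 1 2 3 5 | trailing: 7 9 11
  return 0;
}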

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py (60 additions, 1 deletion)
@@ -581,7 +581,13 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
     apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45), ((12, 0, 15), 14)], cached_k, cached_v)
     apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19), ((14, 0, 17), 19)], cached_k, cached_v)
     apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8), ((16, 5, 80), 10)], cached_k, cached_v)
-    apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)], cached_k, cached_v)
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)],
+        cached_k,
+        cached_v,
+    )

     operation_seq = [
         [(6, 1), (11, 1), (13, 1), (9, 1)],
@@ -607,6 +613,57 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
     assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"


+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    apply_attention(kv_cache, rope_mode, [(0, 30)], cached_k, cached_v)
+    # Fork existing sequences.
+    apply_attention(kv_cache, rope_mode, [((1, 0, -1), 15)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 5)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 2, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((4, 3, -1), 26)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((5, 3, -1), 18)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((6, 5, -1), 22)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((7, 5, -1), 12)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((8, 7, -1), 29)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((9, 7, -1), 9)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((10, 9, -1), 31)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((11, 9, -1), 4)], cached_k, cached_v)
+    # 0 <- 1 <- 2 <- 3 <- 5 <- 7 <- 9 <- 11
+    #                |    |    |    |
+    #                4    6    8    10
+    # Decode.
+    operation_seq = [
+        [(3, 1), (6, 1), (9, 1)],
+        [(4, 1), (8, 1), (10, 1)],
+        [(5, 1), (7, 1), (11, 1)],
+    ]
+    for batch in operation_seq:
+        apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
+
+    num_sequence = 12
+    for i in range(num_sequence):
+        fremove_sequence(kv_cache, i)
+        cached_k.pop(i)
+        cached_v.pop(i)
+        verify_cached_kv(
+            kv_cache,
+            seq_ids=list(range(i + 1, num_sequence)),
+            expected_k=cached_k,
+            expected_v=cached_v,
+        )
+
+    assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
+
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_paged_attention_kv_cache_popn(kv_cache_and_config):
@@ -2526,6 +2583,7 @@ def compact_kv_copy(
     for head_dim, dtype, rope_mode, support_sliding_window in itertools.product(
         HEAD_DIMS, DTYPES, ROPE_MODES, SUPPORT_SLIDING_WINDOW
     ):
+        print(head_dim, dtype, rope_mode, support_sliding_window)
         set_global_func(head_dim, dtype)
         cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window)
         cache_and_config = (cache, rope_mode, support_sliding_window)
@@ -2535,3 +2593,4 @@ def compact_kv_copy(
         test_paged_attention_kv_cache_popn(cache_and_config)
         test_paged_attention_kv_cache_sliding_window(cache_and_config)
         test_paged_attention_kv_cache_tree_attn(cache_and_config)
+        test_paged_attention_kv_cache_unlimited_depth(cache_and_config)
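
Note: the new test drives the cache past the compile-time depth limit: sequence 11's block trace runs 0 <- 1 <- 2 <- 3 <- 5 <- 7 <- 9 <- 11, eight blocks deep, so the decode batches exercise the trailing-block path added in paged_kv_cache.cc, and removing the sequences one by one verifies that the trailing blocks are reclaimed and the cache ends empty.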
